RA.Aid/tests/ra_aid/database/test_utils.py

217 lines
6.4 KiB
Python

"""
Tests for the database utils module.
"""
from unittest.mock import MagicMock, patch
import peewee
import pytest
from ra_aid.database.connection import db_var, init_db
from ra_aid.database.models import BaseModel
from ra_aid.database.utils import ensure_tables_created, get_model_count, truncate_table
def _reset_db_state():
    """Close whatever database ``db_var`` holds and set ``db_var`` to ``None``.

    Closing is best-effort: any exception raised while checking or closing the
    connection is ignored, so fixture setup/teardown is never aborted by a
    half-broken connection.
    """
    db = db_var.get()
    if db is not None:
        try:
            if not db.is_closed():
                db.close()
        except Exception:
            # Deliberately ignore errors when closing the database; the goal is
            # only to leave db_var cleared for the next test.
            pass
    db_var.set(None)


@pytest.fixture
def cleanup_db():
    """Reset the database contextvar and connection state before and after each test."""
    # Start every test from a clean slate, even if a previous test leaked state.
    _reset_db_state()
    yield
    # Tear down whatever database the test created.
    _reset_db_state()
@pytest.fixture
def mock_logger():
    """Replace ``ra_aid.database.utils.logger`` with a mock for one test.

    Yields the mock so a test can assert on ``info``/``error`` calls. The
    original logger is restored when the fixture is torn down.
    """
    patcher = patch("ra_aid.database.utils.logger")
    logger_mock = patcher.start()
    try:
        yield logger_mock
    finally:
        patcher.stop()
@pytest.fixture
def setup_test_model(cleanup_db):
    """Yield a ``TestModel`` class backed by a freshly created in-memory table.

    Requests ``cleanup_db`` so the database context variable is reset around
    the test. The model's table is created (in a transaction) before the test
    body runs and dropped (in a transaction) afterwards.
    """
    from ra_aid.database.models import initialize_database

    memory_db = init_db(in_memory=True)
    # Point the models' database proxy at the database set up above
    # (NOTE(review): assumes initialize_database() picks up init_db's database).
    initialize_database()

    # The class name is significant: it determines the table name and the
    # ``__name__`` that tests embed in expected log messages.
    class TestModel(BaseModel):
        name = peewee.CharField(max_length=100)
        value = peewee.IntegerField(default=0)

    with memory_db.atomic():
        memory_db.create_tables([TestModel], safe=True)

    yield TestModel

    with memory_db.atomic():
        memory_db.drop_tables([TestModel], safe=True)
def test_ensure_tables_created_with_models(cleanup_db, mock_logger):
    """Test ensure_tables_created with explicit models.

    Sets up an in-memory database, defines a model on ``BaseModel``, and checks
    that ``ensure_tables_created`` logs success and leaves a usable table.
    """
    # Initialize the in-memory database. The return value is not needed here;
    # models reach the database through the proxy set up below.
    init_db(in_memory=True)
    from ra_aid.database.models import initialize_database

    initialize_database()

    # Define a test model that uses the proxy database
    class TestModel(BaseModel):
        name = peewee.CharField(max_length=100)
        value = peewee.IntegerField(default=0)

    ensure_tables_created([TestModel])

    mock_logger.info.assert_called_with("Successfully created tables for 1 models")
    # Check the table explicitly so a missing table fails with a clear assertion
    # rather than an OperationalError from the insert below.
    assert TestModel.table_exists()
    TestModel.create(name="test", value=42)
    assert TestModel.select().count() == 1
@patch("ra_aid.database.utils.initialize_database")
def test_ensure_tables_created_database_error(
    mock_initialize_database, setup_test_model, cleanup_db, mock_logger
):
    """ensure_tables_created logs and propagates a DatabaseError from create_tables."""
    failing_db = MagicMock()
    # atomic() must behave as a context manager that does not swallow errors.
    transaction = failing_db.atomic.return_value
    transaction.__enter__.return_value = None
    transaction.__exit__.return_value = None
    failing_db.create_tables.side_effect = peewee.DatabaseError("Test database error")

    # The code under test receives the failing database from initialize_database().
    mock_initialize_database.return_value = failing_db

    with pytest.raises(peewee.DatabaseError):
        ensure_tables_created([setup_test_model])

    mock_logger.error.assert_called_with(
        "Database Error: Failed to create tables: Test database error"
    )
def test_get_model_count(setup_test_model, mock_logger):
    """get_model_count reports exactly the number of rows that were inserted."""
    model_cls = setup_test_model

    # Guarantee an empty table before inserting a known set of rows.
    model_cls.delete().execute()
    for label, number in (("test1", 1), ("test2", 2)):
        model_cls.create(name=label, value=number)

    assert get_model_count(model_cls) == 2
@patch("peewee.ModelSelect.count")
def test_get_model_count_database_error(mock_count, setup_test_model, mock_logger):
    """get_model_count logs a DatabaseError from the count query and returns 0."""
    mock_count.side_effect = peewee.DatabaseError("Test count error")

    result = get_model_count(setup_test_model)

    expected_log = "Database Error: Failed to count records: Test count error"
    mock_logger.error.assert_called_with(expected_log)
    # A failed count is reported as zero rather than raised.
    assert result == 0
def test_truncate_table(setup_test_model, mock_logger):
    """truncate_table removes every row and logs a success message."""
    model_cls = setup_test_model
    for label, number in (("test1", 1), ("test2", 2)):
        model_cls.create(name=label, value=number)
    # Sanity check: the rows are really there before truncating.
    assert model_cls.select().count() == 2

    truncate_table(model_cls)

    expected_message = f"Successfully truncated table for {model_cls.__name__}"
    mock_logger.info.assert_called_with(expected_message)
    assert model_cls.select().count() == 0
@patch("ra_aid.database.models.BaseModel.delete")
def test_truncate_table_database_error(mock_delete, setup_test_model, mock_logger):
    """truncate_table logs and propagates a DatabaseError raised by the delete query."""
    model_cls = setup_test_model
    model_cls.create(name="test", value=42)

    # delete() hands back a query object whose execute() fails.
    failing_query = MagicMock()
    failing_query.execute.side_effect = peewee.DatabaseError("Test delete error")
    mock_delete.return_value = failing_query

    with pytest.raises(peewee.DatabaseError):
        truncate_table(model_cls)

    expected_log = "Database Error: Failed to truncate table: Test delete error"
    mock_logger.error.assert_called_with(expected_log)