RA.Aid/tests/ra_aid/database/test_connection.py

182 lines
6.2 KiB
Python

"""
Tests for the database connection module.
"""
import os
import shutil
from pathlib import Path
import pytest
import peewee
from unittest.mock import patch, MagicMock
from ra_aid.database.connection import (
init_db, get_db, close_db,
db_var, DatabaseManager, logger
)
@pytest.fixture
def cleanup_db(tmp_path, monkeypatch):
    """
    Fixture to isolate and clean up database state between tests.

    Before the test:
    1. Switches the working directory to a per-test temporary directory, so
       the ``.ra-aid/pk.db`` file that the tests create (and that this
       fixture deletes) never belongs to the directory pytest was started in
    2. Closes and forgets any connection already stored in the contextvar,
       so the test does not start with a connection made elsewhere

    After the test:
    1. Closes any open database connection
    2. Resets the contextvar
    3. Removes the database file and its WAL/SHM companions from ``.ra-aid``
       (the directory itself is left in place)

    All cleanup is best-effort: errors are printed, never raised.
    """
    def _reset_connection():
        # Close whatever connection is current, then clear the contextvar.
        close_db()
        db_var.set(None)

    # monkeypatch restores the previous working directory automatically, and
    # does so after this fixture's own teardown has run.
    monkeypatch.chdir(tmp_path)
    try:
        _reset_connection()
    except Exception as e:
        # Log but don't fail if the pre-test reset has issues
        print(f"Cleanup error (non-fatal): {str(e)}")

    # Run the test
    yield

    # Clean up after the test
    try:
        _reset_connection()
        ra_aid_dir = Path(os.getcwd()) / ".ra-aid"
        # Only remove the database files, not the entire directory
        for file_name in ("pk.db", "pk.db-wal", "pk.db-shm"):
            db_artifact = ra_aid_dir / file_name
            if db_artifact.exists():
                db_artifact.unlink()
    except Exception as e:
        # Log but don't fail if cleanup has issues
        print(f"Cleanup error (non-fatal): {str(e)}")
@pytest.fixture
def mock_logger():
    """Swap the connection module's logger for a mock while the test runs."""
    patcher = patch('ra_aid.database.connection.logger')
    fake_logger = patcher.start()
    try:
        yield fake_logger
    finally:
        # Always restore the real logger, even if the test failed.
        patcher.stop()
class TestInitDb:
    """Checks for init_db: default and in-memory modes, reuse and reopening."""

    def test_init_db_default(self, cleanup_db):
        """Calling init_db() with no arguments yields an open, on-disk database."""
        connection = init_db()
        assert isinstance(connection, peewee.SqliteDatabase)
        assert not connection.is_closed()
        assert hasattr(connection, '_is_in_memory')
        assert connection._is_in_memory is False

        # The database file must now exist under the working directory.
        storage_dir = Path.cwd() / ".ra-aid"
        database_path = storage_dir / "pk.db"
        assert storage_dir.exists()
        assert database_path.exists()

    def test_init_db_in_memory(self, cleanup_db):
        """Calling init_db(in_memory=True) yields an open database flagged in-memory."""
        connection = init_db(in_memory=True)
        assert isinstance(connection, peewee.SqliteDatabase)
        assert not connection.is_closed()
        assert hasattr(connection, '_is_in_memory')
        assert connection._is_in_memory is True

    def test_init_db_reuses_connection(self, cleanup_db):
        """Two consecutive init_db() calls hand back the very same object."""
        first = init_db()
        second = init_db()
        assert first is second

    def test_init_db_reopens_closed_connection(self, cleanup_db):
        """init_db() on a closed connection returns that same object, open again."""
        original = init_db()
        original.close()
        assert original.is_closed()

        reopened = init_db()
        assert original is reopened
        assert not original.is_closed()
class TestGetDb:
    """Checks for get_db: lazy creation, reuse and reopening."""

    def test_get_db_creates_connection(self, cleanup_db):
        """With the contextvar cleared, get_db() returns an open, non-in-memory database."""
        # Start from a state where no connection is stored.
        db_var.set(None)

        connection = get_db()
        assert isinstance(connection, peewee.SqliteDatabase)
        assert not connection.is_closed()
        assert hasattr(connection, '_is_in_memory')
        assert connection._is_in_memory is False

    def test_get_db_reuses_connection(self, cleanup_db):
        """get_db() returns the object previously produced by init_db()."""
        initialised = init_db()
        fetched = get_db()
        assert initialised is fetched

    def test_get_db_reopens_closed_connection(self, cleanup_db):
        """get_db() on a closed connection returns that same object, open again."""
        original = init_db()
        original.close()
        assert original.is_closed()

        fetched = get_db()
        assert original is fetched
        assert not original.is_closed()
class TestCloseDb:
    """Checks for close_db: normal close plus the no-op edge cases."""

    def test_close_db(self, cleanup_db):
        """close_db() closes a connection that init_db() opened."""
        connection = init_db()
        assert not connection.is_closed()

        close_db()
        assert connection.is_closed()

    def test_close_db_no_connection(self, cleanup_db):
        """close_db() must not raise when the contextvar holds no connection."""
        # Make sure nothing is stored before calling close_db().
        db_var.set(None)

        # Passing means simply returning without an exception.
        close_db()

    def test_close_db_already_closed(self, cleanup_db):
        """close_db() must not raise when the connection was already closed."""
        connection = init_db()
        connection.close()
        assert connection.is_closed()

        # Passing means simply returning without an exception.
        close_db()
class TestDatabaseManager:
    """Tests for the DatabaseManager context manager."""

    def test_database_manager_default(self, cleanup_db):
        """Test DatabaseManager with default parameters."""
        with DatabaseManager() as db:
            assert isinstance(db, peewee.SqliteDatabase)
            assert not db.is_closed()
            assert hasattr(db, '_is_in_memory')
            assert db._is_in_memory is False

            # Verify the database file was created
            ra_aid_dir = Path(os.getcwd()) / ".ra-aid"
            assert ra_aid_dir.exists()
            assert (ra_aid_dir / "pk.db").exists()

        # Verify the connection is closed after exiting the context
        assert db.is_closed()

    def test_database_manager_in_memory(self, cleanup_db):
        """Test DatabaseManager with in_memory=True."""
        with DatabaseManager(in_memory=True) as db:
            assert isinstance(db, peewee.SqliteDatabase)
            assert not db.is_closed()
            assert hasattr(db, '_is_in_memory')
            assert db._is_in_memory is True

        # Verify the connection is closed after exiting the context
        assert db.is_closed()

    def test_database_manager_exception_handling(self, cleanup_db):
        """Test that DatabaseManager propagates exceptions and still closes."""
        # pytest.raises fails the test if the ValueError does not escape the
        # context manager; a try/except-pass would pass even if it were
        # swallowed.
        with pytest.raises(ValueError, match="Test exception"):
            with DatabaseManager() as db:
                assert not db.is_closed()
                raise ValueError("Test exception")

        # Verify the connection is closed even if an exception occurred
        assert db.is_closed()