RA.Aid/tests/ra_aid/database/test_connection.py

196 lines
6.2 KiB
Python

"""
Tests for the database connection module.
"""
import os
from pathlib import Path
from unittest.mock import patch
import peewee
import pytest
from ra_aid.database.connection import (
DatabaseManager,
close_db,
db_var,
get_db,
init_db,
)
@pytest.fixture
def cleanup_db():
    """
    Reset database state after a test has run.

    Teardown performs three best-effort steps. Each step is guarded on its
    own so that a failure in one (for example ``close_db()`` raising) does
    not skip the others and leak state into later tests:

    1. Close any open database connection via ``close_db()``.
    2. Reset the ``db_var`` contextvar to ``None``.
    3. Delete ``pk.db`` and its ``-wal`` / ``-shm`` side files from the
       ``.ra-aid`` directory under the current working directory. The
       directory itself is left in place.

    Errors are printed rather than raised so cleanup never fails a test.
    """
    # Run the test
    yield

    def _best_effort(action, *args):
        """Run one cleanup step, reporting (not raising) any error."""
        try:
            action(*args)
        except Exception as e:
            # Log but don't fail if cleanup has issues
            print(f"Cleanup error (non-fatal): {e}")

    def _unlink_if_present(path):
        """Delete ``path`` when it exists; do nothing otherwise."""
        if path.exists():
            path.unlink()

    def _remove_db_files():
        """Remove the database file and its WAL/SHM companions."""
        # Only the database files are removed, not the entire directory.
        ra_aid_dir = Path(os.getcwd()) / ".ra-aid"
        for file_name in ("pk.db", "pk.db-wal", "pk.db-shm"):
            # Guard each file separately so one failed unlink does not
            # leave the remaining files behind.
            _best_effort(_unlink_if_present, ra_aid_dir / file_name)

    # Close the connection first so the files are no longer in use when
    # they are deleted.
    _best_effort(close_db)
    _best_effort(db_var.set, None)
    _best_effort(_remove_db_files)
@pytest.fixture
def mock_logger():
    """
    Yield a mock that replaces the connection module's ``logger``.

    The patch is started before the test body runs and is always stopped
    during teardown, so tests can inspect which log calls were made.
    """
    logger_patch = patch("ra_aid.database.connection.logger")
    patched_logger = logger_patch.start()
    try:
        yield patched_logger
    finally:
        logger_patch.stop()
class TestInitDb:
    """Exercise ``init_db`` in its default, in-memory and reuse scenarios."""

    def test_init_db_default(self, cleanup_db):
        """init_db() with no arguments opens a file-backed SQLite database."""
        database = init_db()

        assert isinstance(database, peewee.SqliteDatabase)
        assert not database.is_closed()
        assert hasattr(database, "_is_in_memory")
        assert database._is_in_memory is False

        # The database file must now exist under the working directory.
        storage_dir = Path(os.getcwd()).joinpath(".ra-aid")
        db_path = storage_dir.joinpath("pk.db")
        assert storage_dir.exists()
        assert db_path.exists()

    def test_init_db_in_memory(self, cleanup_db):
        """init_db(in_memory=True) opens a database flagged as in-memory."""
        database = init_db(in_memory=True)

        assert isinstance(database, peewee.SqliteDatabase)
        assert not database.is_closed()
        assert hasattr(database, "_is_in_memory")
        assert database._is_in_memory is True

    def test_init_db_reuses_connection(self, cleanup_db):
        """Two consecutive init_db() calls hand back the same object."""
        first = init_db()
        second = init_db()

        assert first is second

    def test_init_db_reopens_closed_connection(self, cleanup_db):
        """init_db() reopens the existing object after it has been closed."""
        original = init_db()
        original.close()
        assert original.is_closed()

        reopened = init_db()

        # Same object as before, and it is open again.
        assert original is reopened
        assert not original.is_closed()
class TestGetDb:
    """Exercise ``get_db`` when no, an open, or a closed connection exists."""

    def test_get_db_creates_connection(self, cleanup_db):
        """get_db() creates a connection when the contextvar holds none."""
        # Start from a clean slate: no connection stored in the contextvar.
        db_var.set(None)

        database = get_db()

        assert isinstance(database, peewee.SqliteDatabase)
        assert not database.is_closed()
        assert hasattr(database, "_is_in_memory")
        assert database._is_in_memory is False

    def test_get_db_reuses_connection(self, cleanup_db):
        """get_db() returns the object previously created by init_db()."""
        initialised = init_db()
        fetched = get_db()

        assert initialised is fetched

    def test_get_db_reopens_closed_connection(self, cleanup_db):
        """get_db() reopens the existing object after it has been closed."""
        original = init_db()
        original.close()
        assert original.is_closed()

        fetched = get_db()

        # Same object as before, and it is open again.
        assert original is fetched
        assert not original.is_closed()
class TestCloseDb:
    """Exercise ``close_db`` with open, missing and already-closed connections."""

    def test_close_db(self, cleanup_db):
        """close_db() closes a connection that is currently open."""
        database = init_db()
        assert not database.is_closed()

        close_db()

        assert database.is_closed()

    def test_close_db_no_connection(self, cleanup_db):
        """close_db() is a safe no-op when no connection is stored."""
        # Make sure the contextvar holds no connection at all.
        db_var.set(None)

        # Passing means this call did not raise.
        close_db()

    def test_close_db_already_closed(self, cleanup_db):
        """close_db() tolerates a connection that was closed beforehand."""
        database = init_db()
        database.close()
        assert database.is_closed()

        # Passing means this call did not raise.
        close_db()
class TestDatabaseManager:
    """Tests for the DatabaseManager context manager."""

    def test_database_manager_default(self, cleanup_db):
        """Test DatabaseManager with default parameters."""
        with DatabaseManager() as db:
            assert isinstance(db, peewee.SqliteDatabase)
            assert not db.is_closed()
            assert hasattr(db, "_is_in_memory")
            assert db._is_in_memory is False

            # Verify the database file was created
            ra_aid_dir = Path(os.getcwd()) / ".ra-aid"
            assert ra_aid_dir.exists()
            assert (ra_aid_dir / "pk.db").exists()

        # Verify the connection is closed after exiting the context
        assert db.is_closed()

    def test_database_manager_in_memory(self, cleanup_db):
        """Test DatabaseManager with in_memory=True."""
        with DatabaseManager(in_memory=True) as db:
            assert isinstance(db, peewee.SqliteDatabase)
            assert not db.is_closed()
            assert hasattr(db, "_is_in_memory")
            assert db._is_in_memory is True

        # Verify the connection is closed after exiting the context
        assert db.is_closed()

    def test_database_manager_exception_handling(self, cleanup_db):
        """Test that DatabaseManager propagates exceptions and still closes."""
        # pytest.raises fails the test if the ValueError is swallowed by the
        # context manager; a plain try/except would pass either way.
        with pytest.raises(ValueError, match="Test exception"):
            with DatabaseManager() as db:
                assert not db.is_closed()
                raise ValueError("Test exception")

        # Verify the connection is closed even if an exception occurred
        assert db.is_closed()