RA.Aid/ra_aid/database/utils.py

99 lines
3.0 KiB
Python

"""
Database utility functions for ra_aid.
This module provides utility functions for common database operations.
"""
import inspect
from typing import List, Type
import peewee
from ra_aid.database.connection import get_db
from ra_aid.database.models import BaseModel, initialize_database
from ra_aid.logging_config import get_logger
logger = get_logger(__name__)
def ensure_tables_created(models: List[Type[BaseModel]] = None) -> None:
"""
Ensure that database tables for the specified models exist.
If no models are specified, this function will attempt to discover
all models that inherit from BaseModel.
Args:
models: Optional list of model classes to create tables for
"""
db = initialize_database()
if models is None:
# If no models are specified, try to discover them
models = []
try:
# Import all modules that might contain models
# This is a placeholder - in a real implementation, you would
# dynamically discover and import all modules that might contain models
from ra_aid.database import models as models_module
# Find all classes in the module that inherit from BaseModel
for name, obj in inspect.getmembers(models_module):
if (
inspect.isclass(obj)
and issubclass(obj, BaseModel)
and obj != BaseModel
):
models.append(obj)
except ImportError as e:
logger.warning(f"Error importing model modules: {e}")
if not models:
logger.warning("No models found to create tables for")
return
try:
with db.atomic():
db.create_tables(models, safe=True)
logger.info(f"Successfully created tables for {len(models)} models")
except peewee.DatabaseError as e:
logger.error(f"Database Error: Failed to create tables: {str(e)}")
raise
except Exception as e:
logger.error(f"Error: Failed to create tables: {str(e)}")
raise
def get_model_count(model_class: Type[BaseModel]) -> int:
"""
Get the count of records for a specific model.
Args:
model_class: The model class to count records for
Returns:
int: The number of records for the model
"""
try:
return model_class.select().count()
except peewee.DatabaseError as e:
logger.error(f"Database Error: Failed to count records: {str(e)}")
return 0
def truncate_table(model_class: Type[BaseModel]) -> None:
"""
Delete all records from a model's table.
Args:
model_class: The model class to truncate
"""
db = initialize_database()
try:
with db.atomic():
model_class.delete().execute()
logger.info(f"Successfully truncated table for {model_class.__name__}")
except peewee.DatabaseError as e:
logger.error(f"Database Error: Failed to truncate table: {str(e)}")
raise