99 lines
3.0 KiB
Python
99 lines
3.0 KiB
Python
"""
|
|
Database utility functions for ra_aid.
|
|
|
|
This module provides utility functions for common database operations.
|
|
"""
|
|
|
|
import inspect
|
|
from typing import List, Type
|
|
|
|
import peewee
|
|
|
|
from ra_aid.database.connection import get_db
|
|
from ra_aid.database.models import BaseModel, initialize_database
|
|
from ra_aid.logging_config import get_logger
|
|
|
|
logger = get_logger(__name__)
|
|
|
|
|
|
def ensure_tables_created(models: List[Type[BaseModel]] = None) -> None:
|
|
"""
|
|
Ensure that database tables for the specified models exist.
|
|
|
|
If no models are specified, this function will attempt to discover
|
|
all models that inherit from BaseModel.
|
|
|
|
Args:
|
|
models: Optional list of model classes to create tables for
|
|
"""
|
|
db = initialize_database()
|
|
|
|
if models is None:
|
|
# If no models are specified, try to discover them
|
|
models = []
|
|
try:
|
|
# Import all modules that might contain models
|
|
# This is a placeholder - in a real implementation, you would
|
|
# dynamically discover and import all modules that might contain models
|
|
from ra_aid.database import models as models_module
|
|
|
|
# Find all classes in the module that inherit from BaseModel
|
|
for name, obj in inspect.getmembers(models_module):
|
|
if (
|
|
inspect.isclass(obj)
|
|
and issubclass(obj, BaseModel)
|
|
and obj != BaseModel
|
|
):
|
|
models.append(obj)
|
|
except ImportError as e:
|
|
logger.warning(f"Error importing model modules: {e}")
|
|
|
|
if not models:
|
|
logger.warning("No models found to create tables for")
|
|
return
|
|
|
|
try:
|
|
with db.atomic():
|
|
db.create_tables(models, safe=True)
|
|
logger.info(f"Successfully created tables for {len(models)} models")
|
|
except peewee.DatabaseError as e:
|
|
logger.error(f"Database Error: Failed to create tables: {str(e)}")
|
|
raise
|
|
except Exception as e:
|
|
logger.error(f"Error: Failed to create tables: {str(e)}")
|
|
raise
|
|
|
|
|
|
def get_model_count(model_class: Type[BaseModel]) -> int:
|
|
"""
|
|
Get the count of records for a specific model.
|
|
|
|
Args:
|
|
model_class: The model class to count records for
|
|
|
|
Returns:
|
|
int: The number of records for the model
|
|
"""
|
|
try:
|
|
return model_class.select().count()
|
|
except peewee.DatabaseError as e:
|
|
logger.error(f"Database Error: Failed to count records: {str(e)}")
|
|
return 0
|
|
|
|
|
|
def truncate_table(model_class: Type[BaseModel]) -> None:
|
|
"""
|
|
Delete all records from a model's table.
|
|
|
|
Args:
|
|
model_class: The model class to truncate
|
|
"""
|
|
db = initialize_database()
|
|
try:
|
|
with db.atomic():
|
|
model_class.delete().execute()
|
|
logger.info(f"Successfully truncated table for {model_class.__name__}")
|
|
except peewee.DatabaseError as e:
|
|
logger.error(f"Database Error: Failed to truncate table: {str(e)}")
|
|
raise
|