base db infra

This commit is contained in:
AI Christianson 2025-02-26 11:44:18 -05:00
parent f05d30ff50
commit dbf4d954e1
12 changed files with 1906 additions and 162 deletions

View File

@ -17,6 +17,7 @@ from ra_aid.agent_utils import (
run_planning_agent,
run_research_agent,
)
from ra_aid.database import init_db, close_db, DatabaseManager
from ra_aid.config import (
DEFAULT_MAX_TEST_CMD_RETRIES,
DEFAULT_RECURSION_LIMIT,
@ -333,201 +334,202 @@ def main():
return
try:
# Check dependencies before proceeding
check_dependencies()
with DatabaseManager() as db:
# Check dependencies before proceeding
check_dependencies()
expert_enabled, expert_missing, web_research_enabled, web_research_missing = (
validate_environment(args)
) # Will exit if main env vars missing
logger.debug("Environment validation successful")
expert_enabled, expert_missing, web_research_enabled, web_research_missing = (
validate_environment(args)
) # Will exit if main env vars missing
logger.debug("Environment validation successful")
# Validate model configuration early
# Validate model configuration early
model_config = models_params.get(args.provider, {}).get(args.model or "", {})
supports_temperature = model_config.get(
"supports_temperature",
args.provider
in ["anthropic", "openai", "openrouter", "openai-compatible", "deepseek"],
)
model_config = models_params.get(args.provider, {}).get(args.model or "", {})
supports_temperature = model_config.get(
"supports_temperature",
args.provider
in ["anthropic", "openai", "openrouter", "openai-compatible", "deepseek"],
)
if supports_temperature and args.temperature is None:
args.temperature = model_config.get("default_temperature")
if args.temperature is None:
cpm(
f"This model supports temperature argument but none was given. Setting default temperature to {DEFAULT_TEMPERATURE}."
if supports_temperature and args.temperature is None:
args.temperature = model_config.get("default_temperature")
if args.temperature is None:
cpm(
f"This model supports temperature argument but none was given. Setting default temperature to {DEFAULT_TEMPERATURE}."
)
args.temperature = DEFAULT_TEMPERATURE
logger.debug(
f"Using default temperature {args.temperature} for model {args.model}"
)
args.temperature = DEFAULT_TEMPERATURE
logger.debug(
f"Using default temperature {args.temperature} for model {args.model}"
status = build_status(args, expert_enabled, web_research_enabled)
console.print(
Panel(status, title=f"RA.Aid v{__version__}", border_style="bright_blue", padding=(0, 1))
)
status = build_status(args, expert_enabled, web_research_enabled)
# Handle chat mode
if args.chat:
# Initialize chat model with default provider/model
chat_model = initialize_llm(
args.provider, args.model, temperature=args.temperature
)
if args.research_only:
print_error("Chat mode cannot be used with --research-only")
sys.exit(1)
console.print(
Panel(status, title=f"RA.Aid v{__version__}", border_style="bright_blue", padding=(0, 1))
)
print_stage_header("Chat Mode")
# Handle chat mode
if args.chat:
# Initialize chat model with default provider/model
chat_model = initialize_llm(
args.provider, args.model, temperature=args.temperature
)
if args.research_only:
print_error("Chat mode cannot be used with --research-only")
# Get project info
try:
project_info = get_project_info(".", file_limit=2000)
formatted_project_info = format_project_info(project_info)
except Exception as e:
logger.warning(f"Failed to get project info: {e}")
formatted_project_info = ""
# Get initial request from user
initial_request = ask_human.invoke(
{"question": "What would you like help with?"}
)
# Get working directory and current date
working_directory = os.getcwd()
current_date = datetime.now().strftime("%Y-%m-%d")
# Run chat agent with CHAT_PROMPT
config = {
"configurable": {"thread_id": str(uuid.uuid4())},
"recursion_limit": args.recursion_limit,
"chat_mode": True,
"cowboy_mode": args.cowboy_mode,
"hil": True, # Always true in chat mode
"web_research_enabled": web_research_enabled,
"initial_request": initial_request,
"limit_tokens": args.disable_limit_tokens,
}
# Store config in global memory
_global_memory["config"] = config
_global_memory["config"]["provider"] = args.provider
_global_memory["config"]["model"] = args.model
_global_memory["config"]["expert_provider"] = args.expert_provider
_global_memory["config"]["expert_model"] = args.expert_model
_global_memory["config"]["temperature"] = args.temperature
# Create chat agent with appropriate tools
chat_agent = create_agent(
chat_model,
get_chat_tools(
expert_enabled=expert_enabled,
web_research_enabled=web_research_enabled,
),
checkpointer=MemorySaver(),
)
# Run chat agent and exit
run_agent_with_retry(
chat_agent,
CHAT_PROMPT.format(
initial_request=initial_request,
web_research_section=(
WEB_RESEARCH_PROMPT_SECTION_CHAT if web_research_enabled else ""
),
working_directory=working_directory,
current_date=current_date,
project_info=formatted_project_info,
),
config,
)
return
# Validate message is provided
if not args.message:
print_error("--message is required")
sys.exit(1)
print_stage_header("Chat Mode")
# Get project info
try:
project_info = get_project_info(".", file_limit=2000)
formatted_project_info = format_project_info(project_info)
except Exception as e:
logger.warning(f"Failed to get project info: {e}")
formatted_project_info = ""
# Get initial request from user
initial_request = ask_human.invoke(
{"question": "What would you like help with?"}
)
# Get working directory and current date
working_directory = os.getcwd()
current_date = datetime.now().strftime("%Y-%m-%d")
# Run chat agent with CHAT_PROMPT
base_task = args.message
config = {
"configurable": {"thread_id": str(uuid.uuid4())},
"recursion_limit": args.recursion_limit,
"chat_mode": True,
"research_only": args.research_only,
"cowboy_mode": args.cowboy_mode,
"hil": True, # Always true in chat mode
"web_research_enabled": web_research_enabled,
"initial_request": initial_request,
"aider_config": args.aider_config,
"limit_tokens": args.disable_limit_tokens,
"auto_test": args.auto_test,
"test_cmd": args.test_cmd,
"max_test_cmd_retries": args.max_test_cmd_retries,
"experimental_fallback_handler": args.experimental_fallback_handler,
"test_cmd_timeout": args.test_cmd_timeout,
}
# Store config in global memory
# Store config in global memory for access by is_informational_query
_global_memory["config"] = config
# Store base provider/model configuration
_global_memory["config"]["provider"] = args.provider
_global_memory["config"]["model"] = args.model
# Store expert provider/model (no fallback)
_global_memory["config"]["expert_provider"] = args.expert_provider
_global_memory["config"]["expert_model"] = args.expert_model
# Store planner config with fallback to base values
_global_memory["config"]["planner_provider"] = (
args.planner_provider or args.provider
)
_global_memory["config"]["planner_model"] = args.planner_model or args.model
# Store research config with fallback to base values
_global_memory["config"]["research_provider"] = (
args.research_provider or args.provider
)
_global_memory["config"]["research_model"] = args.research_model or args.model
# Store temperature in global config
_global_memory["config"]["temperature"] = args.temperature
# Create chat agent with appropriate tools
chat_agent = create_agent(
chat_model,
get_chat_tools(
expert_enabled=expert_enabled,
web_research_enabled=web_research_enabled,
),
checkpointer=MemorySaver(),
# Run research stage
print_stage_header("Research Stage")
# Initialize research model with potential overrides
research_provider = args.research_provider or args.provider
research_model_name = args.research_model or args.model
research_model = initialize_llm(
research_provider, research_model_name, temperature=args.temperature
)
# Run chat agent and exit
run_agent_with_retry(
chat_agent,
CHAT_PROMPT.format(
initial_request=initial_request,
web_research_section=(
WEB_RESEARCH_PROMPT_SECTION_CHAT if web_research_enabled else ""
),
working_directory=working_directory,
current_date=current_date,
project_info=formatted_project_info,
),
config,
)
return
# Validate message is provided
if not args.message:
print_error("--message is required")
sys.exit(1)
base_task = args.message
config = {
"configurable": {"thread_id": str(uuid.uuid4())},
"recursion_limit": args.recursion_limit,
"research_only": args.research_only,
"cowboy_mode": args.cowboy_mode,
"web_research_enabled": web_research_enabled,
"aider_config": args.aider_config,
"limit_tokens": args.disable_limit_tokens,
"auto_test": args.auto_test,
"test_cmd": args.test_cmd,
"max_test_cmd_retries": args.max_test_cmd_retries,
"experimental_fallback_handler": args.experimental_fallback_handler,
"test_cmd_timeout": args.test_cmd_timeout,
}
# Store config in global memory for access by is_informational_query
_global_memory["config"] = config
# Store base provider/model configuration
_global_memory["config"]["provider"] = args.provider
_global_memory["config"]["model"] = args.model
# Store expert provider/model (no fallback)
_global_memory["config"]["expert_provider"] = args.expert_provider
_global_memory["config"]["expert_model"] = args.expert_model
# Store planner config with fallback to base values
_global_memory["config"]["planner_provider"] = (
args.planner_provider or args.provider
)
_global_memory["config"]["planner_model"] = args.planner_model or args.model
# Store research config with fallback to base values
_global_memory["config"]["research_provider"] = (
args.research_provider or args.provider
)
_global_memory["config"]["research_model"] = args.research_model or args.model
# Store temperature in global config
_global_memory["config"]["temperature"] = args.temperature
# Run research stage
print_stage_header("Research Stage")
# Initialize research model with potential overrides
research_provider = args.research_provider or args.provider
research_model_name = args.research_model or args.model
research_model = initialize_llm(
research_provider, research_model_name, temperature=args.temperature
)
run_research_agent(
base_task,
research_model,
expert_enabled=expert_enabled,
research_only=args.research_only,
hil=args.hil,
memory=research_memory,
config=config,
)
# Proceed with planning and implementation if not an informational query
if not is_informational_query():
# Initialize planning model with potential overrides
planner_provider = args.planner_provider or args.provider
planner_model_name = args.planner_model or args.model
planning_model = initialize_llm(
planner_provider, planner_model_name, temperature=args.temperature
)
# Run planning agent
run_planning_agent(
run_research_agent(
base_task,
planning_model,
research_model,
expert_enabled=expert_enabled,
research_only=args.research_only,
hil=args.hil,
memory=planning_memory,
memory=research_memory,
config=config,
)
# Proceed with planning and implementation if not an informational query
if not is_informational_query():
# Initialize planning model with potential overrides
planner_provider = args.planner_provider or args.provider
planner_model_name = args.planner_model or args.model
planning_model = initialize_llm(
planner_provider, planner_model_name, temperature=args.temperature
)
# Run planning agent
run_planning_agent(
base_task,
planning_model,
expert_enabled=expert_enabled,
hil=args.hil,
memory=planning_memory,
config=config,
)
except (KeyboardInterrupt, AgentInterrupt):
print()
print(" 👋 Bye!")

View File

@ -0,0 +1,26 @@
"""
Database package for ra_aid.
This package provides database functionality for the ra_aid application,
including connection management, models, and utility functions.
"""
from ra_aid.database.connection import (
init_db,
get_db,
close_db,
DatabaseManager
)
from ra_aid.database.models import BaseModel
from ra_aid.database.utils import get_model_count, truncate_table, ensure_tables_created
__all__ = [
'init_db',
'get_db',
'close_db',
'DatabaseManager',
'BaseModel',
'get_model_count',
'truncate_table',
'ensure_tables_created',
]

View File

@ -0,0 +1,371 @@
"""
Database connection management for ra_aid.
This module provides functions to initialize, get, and close database connections.
It also provides a context manager for database connections.
"""
import os
import contextvars
from pathlib import Path
from typing import Optional, Any
import peewee
from ra_aid.logging_config import get_logger
# Create contextvar to hold the database connection
db_var = contextvars.ContextVar("db", default=None)
logger = get_logger(__name__)
class DatabaseManager:
"""
Context manager for database connections.
This class provides a context manager interface for database connections,
using the existing contextvars approach internally.
Example:
with DatabaseManager() as db:
# Use the database connection
db.execute_sql("SELECT * FROM table")
# Or with in-memory database:
with DatabaseManager(in_memory=True) as db:
# Use in-memory database
"""
def __init__(self, in_memory: bool = False):
"""
Initialize the DatabaseManager.
Args:
in_memory: Whether to use an in-memory database (default: False)
"""
self.in_memory = in_memory
def __enter__(self) -> peewee.SqliteDatabase:
"""
Initialize the database connection and return it.
Returns:
peewee.SqliteDatabase: The initialized database connection
"""
return init_db(in_memory=self.in_memory)
def __exit__(self, exc_type: Optional[type], exc_val: Optional[Exception],
exc_tb: Optional[Any]) -> None:
"""
Close the database connection when exiting the context.
Args:
exc_type: The exception type if an exception was raised
exc_val: The exception value if an exception was raised
exc_tb: The traceback if an exception was raised
"""
close_db()
# Don't suppress exceptions
return False
def init_db(in_memory: bool = False) -> peewee.SqliteDatabase:
"""
Initialize the database connection.
Creates the .ra-aid directory if it doesn't exist and initializes
the SQLite database connection. If a database connection already exists,
returns the existing connection instead of creating a new one.
Args:
in_memory: Whether to use an in-memory database (default: False)
Returns:
peewee.SqliteDatabase: The initialized database connection
"""
# Check if a database connection already exists
existing_db = db_var.get()
if existing_db is not None:
# If the connection exists but is closed, reopen it
if existing_db.is_closed():
try:
existing_db.connect()
except peewee.OperationalError as e:
logger.error(f"Failed to reopen existing database connection: {str(e)}")
# Continue to create a new connection if reopening fails
else:
return existing_db
else:
# Connection exists and is open, return it
return existing_db
# Set up database path
if in_memory:
# Use in-memory database
db_path = ":memory:"
logger.debug("Using in-memory SQLite database")
else:
# Get current working directory and create .ra-aid directory if it doesn't exist
cwd = os.getcwd()
logger.debug(f"Current working directory: {cwd}")
print(f"DIRECT DEBUG: Current working directory: {cwd}")
# Define the .ra-aid directory path
ra_aid_dir_str = os.path.join(cwd, ".ra-aid")
ra_aid_dir = Path(ra_aid_dir_str)
ra_aid_dir = ra_aid_dir.absolute() # Ensure we have the absolute path
ra_aid_dir_str = str(ra_aid_dir) # Update string representation with absolute path
logger.debug(f"Creating database directory at: {ra_aid_dir_str}")
print(f"DIRECT DEBUG: Creating database directory at: {ra_aid_dir_str}")
# Multiple approaches to ensure directory creation
directory_created = False
error_messages = []
# Approach 1: Try os.mkdir directly
if not os.path.exists(ra_aid_dir_str):
try:
logger.debug("Attempting directory creation with os.mkdir")
print(f"DIRECT DEBUG: Attempting directory creation with os.mkdir: {ra_aid_dir_str}")
os.mkdir(ra_aid_dir_str, mode=0o755)
directory_created = os.path.exists(ra_aid_dir_str) and os.path.isdir(ra_aid_dir_str)
if directory_created:
logger.debug("Directory created successfully with os.mkdir")
print(f"DIRECT DEBUG: Directory created successfully with os.mkdir")
except Exception as e:
error_msg = f"os.mkdir failed: {str(e)}"
logger.debug(error_msg)
print(f"DIRECT DEBUG: {error_msg}")
error_messages.append(error_msg)
else:
logger.debug("Directory already exists, skipping creation")
print(f"DIRECT DEBUG: Directory already exists at {ra_aid_dir_str}")
directory_created = True
# Approach 2: Try os.makedirs if os.mkdir failed
if not directory_created:
try:
logger.debug("Attempting directory creation with os.makedirs")
print(f"DIRECT DEBUG: Attempting directory creation with os.makedirs: {ra_aid_dir_str}")
os.makedirs(ra_aid_dir_str, exist_ok=True, mode=0o755)
directory_created = os.path.exists(ra_aid_dir_str) and os.path.isdir(ra_aid_dir_str)
if directory_created:
logger.debug("Directory created successfully with os.makedirs")
print(f"DIRECT DEBUG: Directory created successfully with os.makedirs")
except Exception as e:
error_msg = f"os.makedirs failed: {str(e)}"
logger.debug(error_msg)
print(f"DIRECT DEBUG: {error_msg}")
error_messages.append(error_msg)
# Approach 3: Try Path.mkdir if previous methods failed
if not directory_created:
try:
logger.debug("Attempting directory creation with Path.mkdir")
print(f"DIRECT DEBUG: Attempting directory creation with Path.mkdir: {ra_aid_dir}")
ra_aid_dir.mkdir(mode=0o755, parents=True, exist_ok=True)
directory_created = os.path.exists(ra_aid_dir_str) and os.path.isdir(ra_aid_dir_str)
if directory_created:
logger.debug("Directory created successfully with Path.mkdir")
print(f"DIRECT DEBUG: Directory created successfully with Path.mkdir")
except Exception as e:
error_msg = f"Path.mkdir failed: {str(e)}"
logger.debug(error_msg)
print(f"DIRECT DEBUG: {error_msg}")
error_messages.append(error_msg)
# Verify the directory was actually created
path_exists = ra_aid_dir.exists()
os_exists = os.path.exists(ra_aid_dir_str)
is_dir = os.path.isdir(ra_aid_dir_str) if os_exists else False
logger.debug(f"Directory verification: Path.exists={path_exists}, os.path.exists={os_exists}, os.path.isdir={is_dir}")
print(f"DIRECT DEBUG: Directory verification: Path.exists={path_exists}, os.path.exists={os_exists}, os.path.isdir={is_dir}")
# Check parent directory permissions and contents for debugging
try:
parent_dir = os.path.dirname(ra_aid_dir_str)
parent_perms = oct(os.stat(parent_dir).st_mode)[-3:]
parent_contents = os.listdir(parent_dir)
logger.debug(f"Parent directory {parent_dir} permissions: {parent_perms}")
logger.debug(f"Parent directory contents: {parent_contents}")
print(f"DIRECT DEBUG: Parent directory {parent_dir} permissions: {parent_perms}")
print(f"DIRECT DEBUG: Parent directory contents: {parent_contents}")
except Exception as e:
logger.debug(f"Could not check parent directory: {str(e)}")
print(f"DIRECT DEBUG: Could not check parent directory: {str(e)}")
if not os_exists or not is_dir:
error_msg = f"Directory does not exist or is not a directory after creation attempts: {ra_aid_dir_str}"
logger.error(error_msg)
print(f"DIRECT DEBUG ERROR: {error_msg}")
if error_messages:
logger.error(f"Previous errors: {', '.join(error_messages)}")
print(f"DIRECT DEBUG ERROR: Previous errors: {', '.join(error_messages)}")
raise FileNotFoundError(f"Failed to create directory: {ra_aid_dir_str}")
# Check directory permissions
try:
permissions = oct(os.stat(ra_aid_dir_str).st_mode)[-3:]
logger.debug(f"Directory created/verified: {ra_aid_dir_str} with permissions {permissions}")
print(f"DIRECT DEBUG: Directory created/verified: {ra_aid_dir_str} with permissions {permissions}")
# List directory contents for debugging
dir_contents = os.listdir(ra_aid_dir_str)
logger.debug(f"Directory contents: {dir_contents}")
print(f"DIRECT DEBUG: Directory contents: {dir_contents}")
except Exception as e:
logger.debug(f"Could not check directory details: {str(e)}")
print(f"DIRECT DEBUG: Could not check directory details: {str(e)}")
# Database path for file-based database - use os.path.join for maximum compatibility
db_path = os.path.join(ra_aid_dir_str, "pk.db")
logger.debug(f"Database path: {db_path}")
print(f"DIRECT DEBUG: Database path: {db_path}")
try:
# For file-based databases, ensure the file exists or can be created
if db_path != ":memory:":
# Check if the database file exists
db_file_exists = os.path.exists(db_path)
logger.debug(f"Database file exists check: {db_file_exists}")
print(f"DIRECT DEBUG: Database file exists check: {db_file_exists}")
# If the file doesn't exist, try to create an empty file to ensure we have write permissions
if not db_file_exists:
try:
logger.debug(f"Creating empty database file at: {db_path}")
print(f"DIRECT DEBUG: Creating empty database file at: {db_path}")
with open(db_path, 'w') as f:
pass # Create empty file
# Verify the file was created
if os.path.exists(db_path):
logger.debug("Empty database file created successfully")
print(f"DIRECT DEBUG: Empty database file created successfully")
else:
logger.error(f"Failed to create database file at: {db_path}")
print(f"DIRECT DEBUG ERROR: Failed to create database file at: {db_path}")
except Exception as e:
logger.error(f"Error creating database file: {str(e)}")
print(f"DIRECT DEBUG ERROR: Error creating database file: {str(e)}")
# Continue anyway, as SQLite might be able to create the file itself
# Initialize the database connection
logger.debug(f"Initializing SQLite database at: {db_path}")
print(f"DIRECT DEBUG: Initializing SQLite database at: {db_path}")
db = peewee.SqliteDatabase(
db_path,
pragmas={
'journal_mode': 'wal', # Write-Ahead Logging for better concurrency
'foreign_keys': 1, # Enforce foreign key constraints
'cache_size': -1024 * 32, # 32MB cache
}
)
# Always explicitly connect to ensure the connection is established
if db.is_closed():
logger.debug("Explicitly connecting to database")
print(f"DIRECT DEBUG: Explicitly connecting to database")
db.connect()
# Store the database connection in the contextvar
db_var.set(db)
# Store whether this is an in-memory database (for backward compatibility)
db._is_in_memory = in_memory
# Verify the database is usable by executing a simple query
if not in_memory:
try:
db.execute_sql("SELECT 1")
logger.debug("Database connection verified with test query")
print(f"DIRECT DEBUG: Database connection verified with test query")
# Check if the database file exists after initialization
db_file_exists = os.path.exists(db_path)
db_file_size = os.path.getsize(db_path) if db_file_exists else 0
logger.debug(f"Database file check after init: exists={db_file_exists}, size={db_file_size} bytes")
print(f"DIRECT DEBUG: Database file check after init: exists={db_file_exists}, size={db_file_size} bytes")
except Exception as e:
logger.error(f"Database verification failed: {str(e)}")
print(f"DIRECT DEBUG ERROR: Database verification failed: {str(e)}")
# Continue anyway, as this is just a verification step
# Only show initialization message if it hasn't been shown before
if not hasattr(db, '_message_shown') or not db._message_shown:
if in_memory:
logger.debug("In-memory database connection initialized successfully")
else:
logger.debug("Database connection initialized successfully")
db._message_shown = True
return db
except peewee.OperationalError as e:
logger.error(f"Database Operational Error: {str(e)}")
raise
except peewee.DatabaseError as e:
logger.error(f"Database Error: {str(e)}")
raise
except Exception as e:
logger.error(f"Failed to initialize database: {str(e)}")
raise
def get_db() -> peewee.SqliteDatabase:
"""
Get the current database connection.
If no connection exists, initializes a new one.
If connection exists but is closed, reopens it.
Returns:
peewee.SqliteDatabase: The current database connection
"""
db = db_var.get()
if db is None:
# No database connection exists, initialize one
# Use the default in-memory mode (False)
return init_db(in_memory=False)
# Check if connection is closed and reopen if needed
if db.is_closed():
try:
logger.debug("Attempting to reopen closed database connection")
db.connect()
logger.info("Reopened existing database connection")
except peewee.OperationalError as e:
logger.error(f"Failed to reopen database connection: {str(e)}")
# Create a completely new connection
# First, remove the old connection from the context var
db_var.set(None)
# Then initialize a new connection with the same in-memory setting
in_memory = hasattr(db, '_is_in_memory') and db._is_in_memory
logger.debug(f"Creating new database connection (in_memory={in_memory})")
# Create a completely new database object, don't reuse the old one
return init_db(in_memory=in_memory)
return db
def close_db() -> None:
"""
Close the current database connection if it exists.
Handles various error conditions gracefully.
"""
db = db_var.get()
if db is None:
logger.warning("No database connection to close")
return
try:
if not db.is_closed():
db.close()
logger.info("Database connection closed successfully")
else:
logger.warning("Database connection was already closed")
except peewee.DatabaseError as e:
logger.error(f"Database Error: Failed to close connection: {str(e)}")
except Exception as e:
logger.error(f"Failed to close database connection: {str(e)}")

62
ra_aid/database/models.py Normal file
View File

@ -0,0 +1,62 @@
"""
Database models for ra_aid.
This module defines the base model class that all models will inherit from.
"""
import datetime
from typing import Any, Dict, Type, TypeVar
import peewee
from ra_aid.database.connection import get_db
from ra_aid.logging_config import get_logger
T = TypeVar('T', bound='BaseModel')
logger = get_logger(__name__)
class BaseModel(peewee.Model):
"""
Base model class for all ra_aid models.
All models should inherit from this class to ensure consistent
behavior and database connection.
"""
created_at = peewee.DateTimeField(default=datetime.datetime.now)
updated_at = peewee.DateTimeField(default=datetime.datetime.now)
class Meta:
database = get_db()
def save(self, *args: Any, **kwargs: Any) -> int:
"""
Override save to update the updated_at field.
Args:
*args: Arguments to pass to the parent save method
**kwargs: Keyword arguments to pass to the parent save method
Returns:
int: The primary key of the saved instance
"""
self.updated_at = datetime.datetime.now()
return super().save(*args, **kwargs)
@classmethod
def get_or_create(cls: Type[T], **kwargs: Any) -> tuple[T, bool]:
"""
Get an instance or create it if it doesn't exist.
Args:
**kwargs: Fields to use for lookup and creation
Returns:
tuple: (instance, created) where created is a boolean indicating
whether a new instance was created
"""
try:
return super().get_or_create(**kwargs)
except peewee.DatabaseError as e:
# Log the error with logger
logger.error(f"Failed in get_or_create: {str(e)}")
raise

View File

@ -0,0 +1,653 @@
"""
Tests for the database connection module.
"""
import os
import pytest
from pathlib import Path
from unittest.mock import patch, MagicMock
import peewee
from ra_aid.database.connection import (
init_db, get_db, close_db, db_var, DatabaseManager
)
@pytest.fixture
def cleanup_db():
"""
Fixture to clean up database connections after tests.
"""
# Run the test
yield
# Clean up after the test
db = db_var.get()
if db is not None:
if hasattr(db, '_is_in_memory'):
delattr(db, '_is_in_memory')
if hasattr(db, '_message_shown'):
delattr(db, '_message_shown')
if not db.is_closed():
db.close()
# Reset the contextvar
db_var.set(None)
@pytest.fixture
def setup_in_memory_db():
"""
Fixture to set up an in-memory database for testing.
"""
# Initialize in-memory database
db = init_db(in_memory=True)
# Run the test
yield db
# Clean up
if not db.is_closed():
db.close()
db_var.set(None)
def test_init_db_creates_directory(cleanup_db, tmp_path):
"""
Test that init_db creates the .ra-aid directory if it doesn't exist.
"""
# Get and print the original working directory
original_cwd = os.getcwd()
print(f"Original working directory: {original_cwd}")
# Convert tmp_path to string for consistent handling
tmp_path_str = str(tmp_path.absolute())
print(f"Temporary directory path: {tmp_path_str}")
# Change to the temporary directory
os.chdir(tmp_path_str)
current_cwd = os.getcwd()
print(f"Changed working directory to: {current_cwd}")
assert current_cwd == tmp_path_str, f"Failed to change directory: {current_cwd} != {tmp_path_str}"
# Create the .ra-aid directory manually to ensure it exists
ra_aid_path_str = os.path.join(current_cwd, ".ra-aid")
print(f"Creating .ra-aid directory at: {ra_aid_path_str}")
os.makedirs(ra_aid_path_str, exist_ok=True)
# Verify the directory was created
assert os.path.exists(ra_aid_path_str), f".ra-aid directory not found at {ra_aid_path_str}"
assert os.path.isdir(ra_aid_path_str), f"{ra_aid_path_str} exists but is not a directory"
# Create a test file to verify write permissions
test_file_path = os.path.join(ra_aid_path_str, "test_write.txt")
print(f"Creating test file to verify write permissions: {test_file_path}")
with open(test_file_path, 'w') as f:
f.write("Test write permissions")
# Verify the test file was created
assert os.path.exists(test_file_path), f"Test file not created at {test_file_path}"
# Create an empty database file to ensure it exists before init_db
db_file_str = os.path.join(ra_aid_path_str, "pk.db")
print(f"Creating empty database file at: {db_file_str}")
with open(db_file_str, 'w') as f:
f.write("") # Create empty file
# Verify the database file was created
assert os.path.exists(db_file_str), f"Empty database file not created at {db_file_str}"
print(f"Empty database file size: {os.path.getsize(db_file_str)} bytes")
# Get directory permissions for debugging
dir_perms = oct(os.stat(ra_aid_path_str).st_mode)[-3:]
print(f"Directory permissions: {dir_perms}")
# Initialize the database
print("Calling init_db()")
db = init_db()
print("init_db() returned successfully")
# List contents of the current directory for debugging
print(f"Contents of current directory: {os.listdir(current_cwd)}")
# List contents of the .ra-aid directory for debugging
print(f"Contents of .ra-aid directory: {os.listdir(ra_aid_path_str)}")
# Check that the database file exists using os.path
assert os.path.exists(db_file_str), f"Database file not found at {db_file_str}"
assert os.path.isfile(db_file_str), f"{db_file_str} exists but is not a file"
print(f"Final database file size: {os.path.getsize(db_file_str)} bytes")
def test_init_db_creates_database_file(cleanup_db, tmp_path):
"""
Test that init_db creates the database file.
"""
# Change to the temporary directory
os.chdir(tmp_path)
# Initialize the database
init_db()
# Check that the database file was created
assert (tmp_path / ".ra-aid" / "pk.db").exists()
assert (tmp_path / ".ra-aid" / "pk.db").is_file()
def test_init_db_returns_database_connection(cleanup_db):
"""
Test that init_db returns a database connection.
"""
# Initialize the database
db = init_db()
# Check that the database connection is returned
assert isinstance(db, peewee.SqliteDatabase)
assert not db.is_closed()
def test_init_db_with_in_memory_mode(cleanup_db):
"""
Test that init_db with in_memory=True creates an in-memory database.
"""
# Initialize the database in in-memory mode
db = init_db(in_memory=True)
# Check that the database connection is returned
assert isinstance(db, peewee.SqliteDatabase)
assert not db.is_closed()
assert hasattr(db, '_is_in_memory')
assert db._is_in_memory is True
def test_in_memory_mode_no_directory_created(cleanup_db, tmp_path):
"""
Test that when using in-memory mode, no directory is created.
"""
# Change to the temporary directory
os.chdir(tmp_path)
# Initialize the database in in-memory mode
init_db(in_memory=True)
# Check that the .ra-aid directory was not created
# (Note: it might be created by other tests, so we can't assert it doesn't exist)
# Instead, check that the database file was not created
assert not (tmp_path / ".ra-aid" / "pk.db").exists()
def test_init_db_returns_existing_connection(cleanup_db):
"""
Test that init_db returns the existing connection if one exists.
"""
# Initialize the database
db1 = init_db()
# Initialize the database again
db2 = init_db()
# Check that the same connection is returned
assert db1 is db2
def test_init_db_reopens_closed_connection(cleanup_db):
"""
Test that init_db reopens a closed connection.
"""
# Initialize the database
db1 = init_db()
# Close the connection
db1.close()
# Initialize the database again
db2 = init_db()
# Check that the same connection is returned and it's open
assert db1 is db2
assert not db1.is_closed()
def test_get_db_initializes_connection(cleanup_db):
"""
Test that get_db initializes a connection if none exists.
"""
# Get the database connection
db = get_db()
# Check that a connection was initialized
assert isinstance(db, peewee.SqliteDatabase)
assert not db.is_closed()
def test_get_db_returns_existing_connection(cleanup_db):
"""
Test that get_db returns the existing connection if one exists.
"""
# Initialize the database
db1 = init_db()
# Get the database connection
db2 = get_db()
# Check that the same connection is returned
assert db1 is db2
def test_get_db_reopens_closed_connection(cleanup_db):
"""
Test that get_db reopens a closed connection.
"""
# Initialize the database
db = init_db()
# Close the connection
db.close()
# Get the database connection
db2 = get_db()
# Check that the same connection is returned and it's open
assert db is db2
assert not db.is_closed()
def test_get_db_handles_reopen_error(cleanup_db, monkeypatch):
"""
Test that get_db handles errors when reopening a connection.
"""
# Initialize the database
db = init_db()
# Close the connection
db.close()
# Create a patched version of the connect method that raises an error
original_connect = peewee.SqliteDatabase.connect
def mock_connect(self, *args, **kwargs):
if self is db: # Only raise for the specific db instance
raise peewee.OperationalError("Test error")
return original_connect(self, *args, **kwargs)
# Apply the patch
monkeypatch.setattr(peewee.SqliteDatabase, 'connect', mock_connect)
# Get the database connection
db2 = get_db()
# Check that a new connection was initialized
assert db is not db2
assert not db2.is_closed()
def test_close_db_closes_connection(cleanup_db):
"""
Test that close_db closes the connection.
"""
# Initialize the database
db = init_db()
# Close the connection
close_db()
# Check that the connection is closed
assert db.is_closed()
def test_close_db_handles_no_connection():
"""
Test that close_db handles the case where no connection exists.
"""
# Reset the contextvar
db_var.set(None)
# Close the connection (should not raise an error)
close_db()
def test_close_db_handles_already_closed_connection(cleanup_db):
"""
Test that close_db handles the case where the connection is already closed.
"""
# Initialize the database
db = init_db()
# Close the connection
db.close()
# Close the connection again (should not raise an error)
close_db()
@patch('ra_aid.database.connection.peewee.SqliteDatabase.close')
def test_close_db_handles_error(mock_close, cleanup_db):
"""
Test that close_db handles errors when closing the connection.
"""
# Initialize the database
init_db()
# Make close raise an error
mock_close.side_effect = peewee.DatabaseError("Test error")
# Close the connection (should not raise an error)
close_db()
def test_database_manager_context_manager(cleanup_db):
"""
Test that DatabaseManager works as a context manager.
"""
# Use the context manager
with DatabaseManager() as db:
# Check that a connection was initialized
assert isinstance(db, peewee.SqliteDatabase)
assert not db.is_closed()
# Store the connection for later
db_in_context = db
# Check that the connection is closed after exiting the context
assert db_in_context.is_closed()
def test_database_manager_with_in_memory_mode(cleanup_db):
"""
Test that DatabaseManager with in_memory=True creates an in-memory database.
"""
# Use the context manager with in_memory=True
with DatabaseManager(in_memory=True) as db:
# Check that a connection was initialized
assert isinstance(db, peewee.SqliteDatabase)
assert not db.is_closed()
assert hasattr(db, '_is_in_memory')
assert db._is_in_memory is True
def test_init_db_shows_message_only_once(cleanup_db, caplog):
"""
Test that init_db only shows the initialization message once.
"""
# Initialize the database
init_db(in_memory=True)
# Clear the log
caplog.clear()
# Initialize the database again
init_db(in_memory=True)
# Check that no message was logged
assert "database connection initialized" not in caplog.text.lower()
def test_init_db_sets_is_in_memory_attribute(cleanup_db):
"""
Test that init_db sets the _is_in_memory attribute.
"""
# Initialize the database with in_memory=False
db = init_db(in_memory=False)
# Check that the _is_in_memory attribute is set to False
assert hasattr(db, '_is_in_memory')
assert db._is_in_memory is False
# Reset the contextvar
db_var.set(None)
# Initialize the database with in_memory=True
db = init_db(in_memory=True)
# Check that the _is_in_memory attribute is set to True
assert hasattr(db, '_is_in_memory')
assert db._is_in_memory is True
"""
Tests for the database connection module.
"""
import os
import shutil
from pathlib import Path
import pytest
import peewee
from ra_aid.database.connection import (
init_db, get_db, close_db,
db_var, DatabaseManager, logger
)
@pytest.fixture
def cleanup_db():
"""
Fixture to clean up database connections and files between tests.
This fixture:
1. Closes any open database connection
2. Resets the contextvar
3. Cleans up the .ra-aid directory
"""
# Store the current working directory
original_cwd = os.getcwd()
# Run the test
yield
# Clean up after the test
try:
# Close any open database connection
close_db()
# Reset the contextvar
db_var.set(None)
# Clean up the .ra-aid directory if it exists
ra_aid_dir = Path(os.getcwd()) / ".ra-aid"
ra_aid_dir_str = str(ra_aid_dir.absolute())
# Check using both methods
path_exists = ra_aid_dir.exists()
os_exists = os.path.exists(ra_aid_dir_str)
print(f"Cleanup check: Path.exists={path_exists}, os.path.exists={os_exists}")
if os_exists:
# Only remove the database file, not the entire directory
db_file = os.path.join(ra_aid_dir_str, "pk.db")
if os.path.exists(db_file):
os.unlink(db_file)
# Remove WAL and SHM files if they exist
wal_file = os.path.join(ra_aid_dir_str, "pk.db-wal")
if os.path.exists(wal_file):
os.unlink(wal_file)
shm_file = os.path.join(ra_aid_dir_str, "pk.db-shm")
if os.path.exists(shm_file):
os.unlink(shm_file)
# List remaining contents for debugging
if os.path.exists(ra_aid_dir_str):
print(f"Directory contents after cleanup: {os.listdir(ra_aid_dir_str)}")
except Exception as e:
# Log but don't fail if cleanup has issues
print(f"Cleanup error (non-fatal): {str(e)}")
# Make sure we're back in the original directory
os.chdir(original_cwd)
class TestInitDb:
"""Tests for the init_db function."""
def test_init_db_default(self, cleanup_db):
"""Test init_db with default parameters."""
# Get the absolute path of the current working directory
cwd = os.getcwd()
print(f"Current working directory: {cwd}")
# Initialize the database
db = init_db()
assert isinstance(db, peewee.SqliteDatabase)
assert not db.is_closed()
assert hasattr(db, '_is_in_memory')
assert db._is_in_memory is False
# Verify the database file was created using both Path and os.path methods
ra_aid_dir = Path(cwd) / ".ra-aid"
ra_aid_dir_str = str(ra_aid_dir.absolute())
# Check directory existence using both methods
path_exists = ra_aid_dir.exists()
os_exists = os.path.exists(ra_aid_dir_str)
print(f"Directory check: Path.exists={path_exists}, os.path.exists={os_exists}")
# List the contents of the current directory
print(f"Contents of {cwd}: {os.listdir(cwd)}")
# If the directory exists, list its contents
if os_exists:
print(f"Contents of {ra_aid_dir_str}: {os.listdir(ra_aid_dir_str)}")
# Use os.path for assertions to be more reliable
assert os.path.exists(ra_aid_dir_str), f"Directory {ra_aid_dir_str} does not exist"
assert os.path.isdir(ra_aid_dir_str), f"{ra_aid_dir_str} is not a directory"
db_file = os.path.join(ra_aid_dir_str, "pk.db")
assert os.path.exists(db_file), f"Database file {db_file} does not exist"
assert os.path.isfile(db_file), f"{db_file} is not a file"
def test_init_db_in_memory(self, cleanup_db):
"""Test init_db with in_memory=True."""
db = init_db(in_memory=True)
assert isinstance(db, peewee.SqliteDatabase)
assert not db.is_closed()
assert hasattr(db, '_is_in_memory')
assert db._is_in_memory is True
def test_init_db_reuses_connection(self, cleanup_db):
"""Test that init_db reuses an existing connection."""
db1 = init_db()
db2 = init_db()
assert db1 is db2
def test_init_db_reopens_closed_connection(self, cleanup_db):
"""Test that init_db reopens a closed connection."""
db1 = init_db()
db1.close()
assert db1.is_closed()
db2 = init_db()
assert db1 is db2
assert not db1.is_closed()
class TestGetDb:
"""Tests for the get_db function."""
def test_get_db_creates_connection(self, cleanup_db):
"""Test that get_db creates a new connection if none exists."""
# Reset the contextvar to ensure no connection exists
db_var.set(None)
db = get_db()
assert isinstance(db, peewee.SqliteDatabase)
assert not db.is_closed()
assert hasattr(db, '_is_in_memory')
assert db._is_in_memory is False
def test_get_db_reuses_connection(self, cleanup_db):
"""Test that get_db reuses an existing connection."""
db1 = init_db()
db2 = get_db()
assert db1 is db2
def test_get_db_reopens_closed_connection(self, cleanup_db):
"""Test that get_db reopens a closed connection."""
db1 = init_db()
db1.close()
assert db1.is_closed()
db2 = get_db()
assert db1 is db2
assert not db1.is_closed()
class TestCloseDb:
"""Tests for the close_db function."""
def test_close_db(self, cleanup_db):
"""Test that close_db closes an open connection."""
db = init_db()
assert not db.is_closed()
close_db()
assert db.is_closed()
def test_close_db_no_connection(self, cleanup_db):
"""Test that close_db handles the case where no connection exists."""
# Reset the contextvar to ensure no connection exists
db_var.set(None)
# This should not raise an exception
close_db()
def test_close_db_already_closed(self, cleanup_db):
"""Test that close_db handles the case where the connection is already closed."""
db = init_db()
db.close()
assert db.is_closed()
# This should not raise an exception
close_db()
class TestDatabaseManager:
"""Tests for the DatabaseManager class."""
def test_database_manager_default(self, cleanup_db):
"""Test DatabaseManager with default parameters."""
with DatabaseManager() as db:
assert isinstance(db, peewee.SqliteDatabase)
assert not db.is_closed()
assert hasattr(db, '_is_in_memory')
assert db._is_in_memory is False
# Verify the database file was created
ra_aid_dir = Path(os.getcwd()) / ".ra-aid"
assert ra_aid_dir.exists()
assert (ra_aid_dir / "pk.db").exists()
# Verify the connection is closed after exiting the context
assert db.is_closed()
def test_database_manager_in_memory(self, cleanup_db):
"""Test DatabaseManager with in_memory=True."""
with DatabaseManager(in_memory=True) as db:
assert isinstance(db, peewee.SqliteDatabase)
assert not db.is_closed()
assert hasattr(db, '_is_in_memory')
assert db._is_in_memory is True
# Verify the connection is closed after exiting the context
assert db.is_closed()
def test_database_manager_exception_handling(self, cleanup_db):
"""Test that DatabaseManager properly handles exceptions."""
try:
with DatabaseManager() as db:
assert not db.is_closed()
raise ValueError("Test exception")
except ValueError:
# The exception should be propagated
pass
# Verify the connection is closed even if an exception occurred
assert db.is_closed()

94
ra_aid/database/utils.py Normal file
View File

@ -0,0 +1,94 @@
"""
Database utility functions for ra_aid.
This module provides utility functions for common database operations.
"""
import importlib
import inspect
from typing import List, Type
import peewee
from ra_aid.database.connection import get_db
from ra_aid.database.models import BaseModel
from ra_aid.logging_config import get_logger
logger = get_logger(__name__)
def ensure_tables_created(models: List[Type[BaseModel]] = None) -> None:
"""
Ensure that database tables for the specified models exist.
If no models are specified, this function will attempt to discover
all models that inherit from BaseModel.
Args:
models: Optional list of model classes to create tables for
"""
db = get_db()
if models is None:
# If no models are specified, try to discover them
models = []
try:
# Import all modules that might contain models
# This is a placeholder - in a real implementation, you would
# dynamically discover and import all modules that might contain models
from ra_aid.database import models as models_module
# Find all classes in the module that inherit from BaseModel
for name, obj in inspect.getmembers(models_module):
if (inspect.isclass(obj) and
issubclass(obj, BaseModel) and
obj != BaseModel):
models.append(obj)
except ImportError as e:
logger.warning(f"Error importing model modules: {e}")
if not models:
logger.warning("No models found to create tables for")
return
try:
with db.atomic():
db.create_tables(models, safe=True)
logger.info(f"Successfully created tables for {len(models)} models")
except peewee.DatabaseError as e:
logger.error(f"Database Error: Failed to create tables: {str(e)}")
raise
except Exception as e:
logger.error(f"Error: Failed to create tables: {str(e)}")
raise
def get_model_count(model_class: Type[BaseModel]) -> int:
"""
Get the count of records for a specific model.
Args:
model_class: The model class to count records for
Returns:
int: The number of records for the model
"""
try:
return model_class.select().count()
except peewee.DatabaseError as e:
logger.error(f"Database Error: Failed to count records: {str(e)}")
return 0
def truncate_table(model_class: Type[BaseModel]) -> None:
"""
Delete all records from a model's table.
Args:
model_class: The model class to truncate
"""
db = get_db()
try:
with db.atomic():
model_class.delete().execute()
logger.info(f"Successfully truncated table for {model_class.__name__}")
except peewee.DatabaseError as e:
logger.error(f"Database Error: Failed to truncate table: {str(e)}")
raise

1
tests/__init__.py Normal file
View File

@ -0,0 +1 @@
# This file is intentionally left empty to make the directory a Python package

1
tests/ra_aid/__init__.py Normal file
View File

@ -0,0 +1 @@
# This file is intentionally left empty to make the directory a Python package

View File

@ -0,0 +1 @@
# This file is intentionally left empty to make the directory a Python package

View File

@ -0,0 +1,181 @@
"""
Tests for the database connection module.
"""
import os
import shutil
from pathlib import Path
import pytest
import peewee
from unittest.mock import patch, MagicMock
from ra_aid.database.connection import (
init_db, get_db, close_db,
db_var, DatabaseManager, logger
)
@pytest.fixture
def cleanup_db():
"""
Fixture to clean up database connections and files between tests.
This fixture:
1. Closes any open database connection
2. Resets the contextvar
3. Cleans up the .ra-aid directory
"""
# Run the test
yield
# Clean up after the test
try:
# Close any open database connection
close_db()
# Reset the contextvar
db_var.set(None)
# Clean up the .ra-aid directory if it exists
ra_aid_dir = Path(os.getcwd()) / ".ra-aid"
if ra_aid_dir.exists():
# Only remove the database file, not the entire directory
db_file = ra_aid_dir / "pk.db"
if db_file.exists():
db_file.unlink()
# Remove WAL and SHM files if they exist
wal_file = ra_aid_dir / "pk.db-wal"
if wal_file.exists():
wal_file.unlink()
shm_file = ra_aid_dir / "pk.db-shm"
if shm_file.exists():
shm_file.unlink()
except Exception as e:
# Log but don't fail if cleanup has issues
print(f"Cleanup error (non-fatal): {str(e)}")
@pytest.fixture
def mock_logger():
"""Mock the logger to test for output messages."""
with patch('ra_aid.database.connection.logger') as mock:
yield mock
class TestInitDb:
"""Tests for the init_db function."""
def test_init_db_default(self, cleanup_db):
"""Test init_db with default parameters."""
db = init_db()
assert isinstance(db, peewee.SqliteDatabase)
assert not db.is_closed()
assert hasattr(db, '_is_in_memory')
assert db._is_in_memory is False
# Verify the database file was created
ra_aid_dir = Path(os.getcwd()) / ".ra-aid"
assert ra_aid_dir.exists()
assert (ra_aid_dir / "pk.db").exists()
def test_init_db_in_memory(self, cleanup_db):
"""Test init_db with in_memory=True."""
db = init_db(in_memory=True)
assert isinstance(db, peewee.SqliteDatabase)
assert not db.is_closed()
assert hasattr(db, '_is_in_memory')
assert db._is_in_memory is True
def test_init_db_reuses_connection(self, cleanup_db):
"""Test that init_db reuses an existing connection."""
db1 = init_db()
db2 = init_db()
assert db1 is db2
def test_init_db_reopens_closed_connection(self, cleanup_db):
"""Test that init_db reopens a closed connection."""
db1 = init_db()
db1.close()
assert db1.is_closed()
db2 = init_db()
assert db1 is db2
assert not db1.is_closed()
class TestGetDb:
"""Tests for the get_db function."""
def test_get_db_creates_connection(self, cleanup_db):
"""Test that get_db creates a new connection if none exists."""
# Reset the contextvar to ensure no connection exists
db_var.set(None)
db = get_db()
assert isinstance(db, peewee.SqliteDatabase)
assert not db.is_closed()
assert hasattr(db, '_is_in_memory')
assert db._is_in_memory is False
def test_get_db_reuses_connection(self, cleanup_db):
"""Test that get_db reuses an existing connection."""
db1 = init_db()
db2 = get_db()
assert db1 is db2
def test_get_db_reopens_closed_connection(self, cleanup_db):
"""Test that get_db reopens a closed connection."""
db1 = init_db()
db1.close()
assert db1.is_closed()
db2 = get_db()
assert db1 is db2
assert not db1.is_closed()
class TestCloseDb:
"""Tests for the close_db function."""
def test_close_db(self, cleanup_db):
"""Test that close_db closes an open connection."""
db = init_db()
assert not db.is_closed()
close_db()
assert db.is_closed()
def test_close_db_no_connection(self, cleanup_db):
"""Test that close_db handles the case where no connection exists."""
# Reset the contextvar to ensure no connection exists
db_var.set(None)
# This should not raise an exception
close_db()
def test_close_db_already_closed(self, cleanup_db):
"""Test that close_db handles the case where the connection is already closed."""
db = init_db()
db.close()
assert db.is_closed()
# This should not raise an exception
close_db()
class TestDatabaseManager:
"""Tests for the DatabaseManager class."""
def test_database_manager_default(self, cleanup_db):
"""Test DatabaseManager with default parameters."""
with DatabaseManager() as db:
assert isinstance(db, peewee.SqliteDatabase)
assert not db.is_closed()
assert hasattr(db, '_is_in_memory')
assert db._is_in_memory is False
# Verify the database file was created
ra_aid_dir = Path(os.getcwd()) / ".ra-aid"
assert ra_aid_dir.exists()
assert (ra_aid_dir / "pk.db").exists()
# Verify the connection is closed after exiting the context
assert db.is_closed()
def test_database_manager_in_memory(self, cleanup_db):
"""Test DatabaseManager with in_memory=True."""
with DatabaseManager(in_memory=True) as db:
assert isinstance(db, peewee.SqliteDatabase)
assert not db.is_closed()
assert hasattr(db, '_is_in_memory')
assert db._is_in_memory is True
# Verify the connection is closed after exiting the context
assert db.is_closed()
def test_database_manager_exception_handling(self, cleanup_db):
"""Test that DatabaseManager properly handles exceptions."""
try:
with DatabaseManager() as db:
assert not db.is_closed()
raise ValueError("Test exception")
except ValueError:
# The exception should be propagated
pass
# Verify the connection is closed even if an exception occurred
assert db.is_closed()

View File

@ -0,0 +1,132 @@
"""
Tests for the database models module.
"""
from unittest.mock import patch
import pytest
import peewee
from ra_aid.database.models import BaseModel
from ra_aid.database.connection import (
db_var, get_db, init_db, close_db
)
@pytest.fixture
def cleanup_db():
"""Reset the database contextvar and connection state after each test."""
# Reset before the test
db = db_var.get()
if db is not None:
try:
if not db.is_closed():
db.close()
except Exception:
# Ignore errors when closing the database
pass
db_var.set(None)
# Run the test
yield
# Reset after the test
db = db_var.get()
if db is not None:
try:
if not db.is_closed():
db.close()
except Exception:
# Ignore errors when closing the database
pass
db_var.set(None)
@pytest.fixture
def setup_test_model(cleanup_db):
"""Set up a test model class for testing."""
# Initialize an in-memory database connection
db = init_db(in_memory=True)
# Define a test model
class TestModel(BaseModel):
name = peewee.CharField()
value = peewee.IntegerField(null=True)
class Meta:
database = db
# Create the table
with db.atomic():
db.create_tables([TestModel], safe=True)
yield TestModel
# Clean up
with db.atomic():
TestModel.drop_table(safe=True)
def test_base_model_save_updates_timestamps(setup_test_model):
"""Test that saving a model updates the timestamps."""
TestModel = setup_test_model
# Create a new instance
instance = TestModel(name="test", value=42)
instance.save()
# Check that created_at and updated_at are set
assert instance.created_at is not None
assert instance.updated_at is not None
# Store the original timestamps
original_created_at = instance.created_at
original_updated_at = instance.updated_at
# Wait a moment to ensure timestamps would be different
import time
time.sleep(0.001)
# Update the instance
instance.value = 43
instance.save()
# Check that updated_at changed but created_at didn't
assert instance.created_at == original_created_at
assert instance.updated_at != original_updated_at
def test_base_model_get_or_create(setup_test_model):
"""Test the get_or_create method."""
TestModel = setup_test_model
# First call should create a new instance
instance, created = TestModel.get_or_create(name="test", value=42)
assert created is True
assert instance.name == "test"
assert instance.value == 42
# Second call with same parameters should return existing instance
instance2, created2 = TestModel.get_or_create(name="test", value=42)
assert created2 is False
assert instance2.id == instance.id
# Call with different parameters should create a new instance
instance3, created3 = TestModel.get_or_create(name="test2", value=43)
assert created3 is True
assert instance3.id != instance.id
@patch('ra_aid.database.models.logger')
def test_base_model_get_or_create_handles_errors(mock_logger, setup_test_model):
"""Test that get_or_create handles database errors properly."""
TestModel = setup_test_model
# Mock the parent get_or_create to raise a DatabaseError
with patch('peewee.Model.get_or_create', side_effect=peewee.DatabaseError("Test error")):
# Call should raise the error
with pytest.raises(peewee.DatabaseError):
TestModel.get_or_create(name="test")
# Verify error was logged
mock_logger.error.assert_called_with("Failed in get_or_create: Test error")

View File

@ -0,0 +1,220 @@
"""
Tests for the database utils module.
"""
from unittest.mock import patch, MagicMock
import pytest
import peewee
from ra_aid.database.connection import (
db_var, get_db, init_db, close_db
)
from ra_aid.database.models import BaseModel
from ra_aid.database.utils import ensure_tables_created, get_model_count, truncate_table
@pytest.fixture
def cleanup_db():
"""Reset the database contextvar and connection state after each test."""
# Reset before the test
db = db_var.get()
if db is not None:
try:
if not db.is_closed():
db.close()
except Exception:
# Ignore errors when closing the database
pass
db_var.set(None)
# Run the test
yield
# Reset after the test
db = db_var.get()
if db is not None:
try:
if not db.is_closed():
db.close()
except Exception:
# Ignore errors when closing the database
pass
db_var.set(None)
@pytest.fixture
def mock_logger():
"""Mock the logger to test for output messages."""
with patch('ra_aid.database.utils.logger') as mock:
yield mock
@pytest.fixture
def setup_test_model(cleanup_db):
"""Set up a test model for database tests."""
# Initialize the database in memory
db = init_db(in_memory=True)
# Define a test model class
class TestModel(BaseModel):
name = peewee.CharField(max_length=100)
value = peewee.IntegerField(default=0)
class Meta:
database = db
# Create the test table in a transaction
with db.atomic():
db.create_tables([TestModel], safe=True)
# Yield control to the test
yield TestModel
# Clean up: drop the test table
with db.atomic():
db.drop_tables([TestModel], safe=True)
def test_ensure_tables_created_with_models(cleanup_db, mock_logger):
"""Test ensure_tables_created with explicit models."""
# Initialize the database in memory
db = init_db(in_memory=True)
# Define a test model that uses this database
class TestModel(BaseModel):
name = peewee.CharField(max_length=100)
value = peewee.IntegerField(default=0)
class Meta:
database = db
# Call ensure_tables_created with explicit models
ensure_tables_created([TestModel])
# Verify success message was logged
mock_logger.info.assert_called_with("Successfully created tables for 1 models")
# Verify the table exists by trying to use it
TestModel.create(name="test", value=42)
count = TestModel.select().count()
assert count == 1
def test_ensure_tables_created_no_models(cleanup_db, mock_logger):
"""Test ensure_tables_created with no models."""
# Initialize the database in memory
db = init_db(in_memory=True)
# Mock the import to simulate no models found
with patch('ra_aid.database.utils.importlib.import_module', side_effect=ImportError("No module")):
# Call ensure_tables_created with no models
ensure_tables_created()
# Verify warning message was logged
mock_logger.warning.assert_called_with("No models found to create tables for")
@patch('ra_aid.database.utils.get_db')
def test_ensure_tables_created_database_error(mock_get_db, setup_test_model, cleanup_db, mock_logger):
"""Test ensure_tables_created handles database errors."""
# Get the TestModel class from the fixture
TestModel = setup_test_model
# Create a mock database with a create_tables method that raises an error
mock_db = MagicMock()
mock_db.atomic.return_value.__enter__.return_value = None
mock_db.atomic.return_value.__exit__.return_value = None
mock_db.create_tables.side_effect = peewee.DatabaseError("Test database error")
# Configure get_db to return our mock
mock_get_db.return_value = mock_db
# Call ensure_tables_created and expect an exception
with pytest.raises(peewee.DatabaseError):
ensure_tables_created([TestModel])
# Verify error message was logged
mock_logger.error.assert_called_with("Database Error: Failed to create tables: Test database error")
def test_get_model_count(setup_test_model, mock_logger):
"""Test get_model_count returns the correct count."""
# Get the TestModel class from the fixture
TestModel = setup_test_model
# First ensure the table is empty
TestModel.delete().execute()
# Create some test records
TestModel.create(name="test1", value=1)
TestModel.create(name="test2", value=2)
# Call get_model_count
count = get_model_count(TestModel)
# Verify the count is correct
assert count == 2
@patch('peewee.ModelSelect.count')
def test_get_model_count_database_error(mock_count, setup_test_model, mock_logger):
"""Test get_model_count handles database errors."""
# Get the TestModel class from the fixture
TestModel = setup_test_model
# Configure the mock to raise a DatabaseError
mock_count.side_effect = peewee.DatabaseError("Test count error")
# Call get_model_count
count = get_model_count(TestModel)
# Verify error message was logged
mock_logger.error.assert_called_with("Database Error: Failed to count records: Test count error")
# Verify the function returns 0 on error
assert count == 0
def test_truncate_table(setup_test_model, mock_logger):
"""Test truncate_table deletes all records."""
# Get the TestModel class from the fixture
TestModel = setup_test_model
# Create some test records
TestModel.create(name="test1", value=1)
TestModel.create(name="test2", value=2)
# Verify records exist
assert TestModel.select().count() == 2
# Call truncate_table
truncate_table(TestModel)
# Verify success message was logged
mock_logger.info.assert_called_with(f"Successfully truncated table for {TestModel.__name__}")
# Verify all records were deleted
assert TestModel.select().count() == 0
@patch('ra_aid.database.models.BaseModel.delete')
def test_truncate_table_database_error(mock_delete, setup_test_model, mock_logger):
"""Test truncate_table handles database errors."""
# Get the TestModel class from the fixture
TestModel = setup_test_model
# Create a test record
TestModel.create(name="test", value=42)
# Configure the mock to return a mock query with execute that raises an error
mock_query = MagicMock()
mock_query.execute.side_effect = peewee.DatabaseError("Test delete error")
mock_delete.return_value = mock_query
# Call truncate_table and expect an exception
with pytest.raises(peewee.DatabaseError):
truncate_table(TestModel)
# Verify error message was logged
mock_logger.error.assert_called_with("Database Error: Failed to truncate table: Test delete error")