From dbf4d954e15c200f19f593e2fba784e619dee3ec Mon Sep 17 00:00:00 2001 From: AI Christianson Date: Wed, 26 Feb 2025 11:44:18 -0500 Subject: [PATCH] base db infra --- ra_aid/__main__.py | 326 ++++----- ra_aid/database/__init__.py | 26 + ra_aid/database/connection.py | 371 ++++++++++ ra_aid/database/models.py | 62 ++ .../tests/ra_aid/database/test_connection.py | 653 ++++++++++++++++++ ra_aid/database/utils.py | 94 +++ tests/__init__.py | 1 + tests/ra_aid/__init__.py | 1 + tests/ra_aid/database/__init__.py | 1 + tests/ra_aid/database/test_connection.py | 181 +++++ tests/ra_aid/database/test_models.py | 132 ++++ tests/ra_aid/database/test_utils.py | 220 ++++++ 12 files changed, 1906 insertions(+), 162 deletions(-) create mode 100644 ra_aid/database/__init__.py create mode 100644 ra_aid/database/connection.py create mode 100644 ra_aid/database/models.py create mode 100644 ra_aid/database/tests/ra_aid/database/test_connection.py create mode 100644 ra_aid/database/utils.py create mode 100644 tests/__init__.py create mode 100644 tests/ra_aid/__init__.py create mode 100644 tests/ra_aid/database/__init__.py create mode 100644 tests/ra_aid/database/test_connection.py create mode 100644 tests/ra_aid/database/test_models.py create mode 100644 tests/ra_aid/database/test_utils.py diff --git a/ra_aid/__main__.py b/ra_aid/__main__.py index 94d00c9..b8191a9 100644 --- a/ra_aid/__main__.py +++ b/ra_aid/__main__.py @@ -17,6 +17,7 @@ from ra_aid.agent_utils import ( run_planning_agent, run_research_agent, ) +from ra_aid.database import init_db, close_db, DatabaseManager from ra_aid.config import ( DEFAULT_MAX_TEST_CMD_RETRIES, DEFAULT_RECURSION_LIMIT, @@ -333,201 +334,202 @@ def main(): return try: - # Check dependencies before proceeding - check_dependencies() + with DatabaseManager() as db: + # Check dependencies before proceeding + check_dependencies() - expert_enabled, expert_missing, web_research_enabled, web_research_missing = ( - validate_environment(args) - ) # Will exit if main env vars missing - logger.debug("Environment validation successful") + expert_enabled, expert_missing, web_research_enabled, web_research_missing = ( + validate_environment(args) + ) # Will exit if main env vars missing + logger.debug("Environment validation successful") - # Validate model configuration early + # Validate model configuration early + model_config = models_params.get(args.provider, {}).get(args.model or "", {}) + supports_temperature = model_config.get( + "supports_temperature", + args.provider + in ["anthropic", "openai", "openrouter", "openai-compatible", "deepseek"], + ) - model_config = models_params.get(args.provider, {}).get(args.model or "", {}) - supports_temperature = model_config.get( - "supports_temperature", - args.provider - in ["anthropic", "openai", "openrouter", "openai-compatible", "deepseek"], - ) - - if supports_temperature and args.temperature is None: - args.temperature = model_config.get("default_temperature") - if args.temperature is None: - cpm( - f"This model supports temperature argument but none was given. Setting default temperature to {DEFAULT_TEMPERATURE}." + if supports_temperature and args.temperature is None: + args.temperature = model_config.get("default_temperature") + if args.temperature is None: + cpm( + f"This model supports temperature argument but none was given. Setting default temperature to {DEFAULT_TEMPERATURE}." + ) + args.temperature = DEFAULT_TEMPERATURE + logger.debug( + f"Using default temperature {args.temperature} for model {args.model}" ) - args.temperature = DEFAULT_TEMPERATURE - logger.debug( - f"Using default temperature {args.temperature} for model {args.model}" + + status = build_status(args, expert_enabled, web_research_enabled) + + console.print( + Panel(status, title=f"RA.Aid v{__version__}", border_style="bright_blue", padding=(0, 1)) ) - status = build_status(args, expert_enabled, web_research_enabled) + # Handle chat mode + if args.chat: + # Initialize chat model with default provider/model + chat_model = initialize_llm( + args.provider, args.model, temperature=args.temperature + ) + + if args.research_only: + print_error("Chat mode cannot be used with --research-only") + sys.exit(1) - console.print( - Panel(status, title=f"RA.Aid v{__version__}", border_style="bright_blue", padding=(0, 1)) - ) + print_stage_header("Chat Mode") - # Handle chat mode - if args.chat: - # Initialize chat model with default provider/model - chat_model = initialize_llm( - args.provider, args.model, temperature=args.temperature - ) - if args.research_only: - print_error("Chat mode cannot be used with --research-only") + # Get project info + try: + project_info = get_project_info(".", file_limit=2000) + formatted_project_info = format_project_info(project_info) + except Exception as e: + logger.warning(f"Failed to get project info: {e}") + formatted_project_info = "" + + # Get initial request from user + initial_request = ask_human.invoke( + {"question": "What would you like help with?"} + ) + + # Get working directory and current date + working_directory = os.getcwd() + current_date = datetime.now().strftime("%Y-%m-%d") + + # Run chat agent with CHAT_PROMPT + config = { + "configurable": {"thread_id": str(uuid.uuid4())}, + "recursion_limit": args.recursion_limit, + "chat_mode": True, + "cowboy_mode": args.cowboy_mode, + "hil": True, # Always true in chat mode + "web_research_enabled": web_research_enabled, + "initial_request": initial_request, + "limit_tokens": args.disable_limit_tokens, + } + + # Store config in global memory + _global_memory["config"] = config + _global_memory["config"]["provider"] = args.provider + _global_memory["config"]["model"] = args.model + _global_memory["config"]["expert_provider"] = args.expert_provider + _global_memory["config"]["expert_model"] = args.expert_model + _global_memory["config"]["temperature"] = args.temperature + + # Create chat agent with appropriate tools + chat_agent = create_agent( + chat_model, + get_chat_tools( + expert_enabled=expert_enabled, + web_research_enabled=web_research_enabled, + ), + checkpointer=MemorySaver(), + ) + + # Run chat agent and exit + run_agent_with_retry( + chat_agent, + CHAT_PROMPT.format( + initial_request=initial_request, + web_research_section=( + WEB_RESEARCH_PROMPT_SECTION_CHAT if web_research_enabled else "" + ), + working_directory=working_directory, + current_date=current_date, + project_info=formatted_project_info, + ), + config, + ) + return + + # Validate message is provided + if not args.message: + print_error("--message is required") sys.exit(1) - print_stage_header("Chat Mode") - - # Get project info - try: - project_info = get_project_info(".", file_limit=2000) - formatted_project_info = format_project_info(project_info) - except Exception as e: - logger.warning(f"Failed to get project info: {e}") - formatted_project_info = "" - - # Get initial request from user - initial_request = ask_human.invoke( - {"question": "What would you like help with?"} - ) - - # Get working directory and current date - working_directory = os.getcwd() - current_date = datetime.now().strftime("%Y-%m-%d") - - # Run chat agent with CHAT_PROMPT + base_task = args.message config = { "configurable": {"thread_id": str(uuid.uuid4())}, "recursion_limit": args.recursion_limit, - "chat_mode": True, + "research_only": args.research_only, "cowboy_mode": args.cowboy_mode, - "hil": True, # Always true in chat mode "web_research_enabled": web_research_enabled, - "initial_request": initial_request, + "aider_config": args.aider_config, "limit_tokens": args.disable_limit_tokens, + "auto_test": args.auto_test, + "test_cmd": args.test_cmd, + "max_test_cmd_retries": args.max_test_cmd_retries, + "experimental_fallback_handler": args.experimental_fallback_handler, + "test_cmd_timeout": args.test_cmd_timeout, } - # Store config in global memory + # Store config in global memory for access by is_informational_query _global_memory["config"] = config + + # Store base provider/model configuration _global_memory["config"]["provider"] = args.provider _global_memory["config"]["model"] = args.model + + # Store expert provider/model (no fallback) _global_memory["config"]["expert_provider"] = args.expert_provider _global_memory["config"]["expert_model"] = args.expert_model + + # Store planner config with fallback to base values + _global_memory["config"]["planner_provider"] = ( + args.planner_provider or args.provider + ) + _global_memory["config"]["planner_model"] = args.planner_model or args.model + + # Store research config with fallback to base values + _global_memory["config"]["research_provider"] = ( + args.research_provider or args.provider + ) + _global_memory["config"]["research_model"] = args.research_model or args.model + + # Store temperature in global config _global_memory["config"]["temperature"] = args.temperature - # Create chat agent with appropriate tools - chat_agent = create_agent( - chat_model, - get_chat_tools( - expert_enabled=expert_enabled, - web_research_enabled=web_research_enabled, - ), - checkpointer=MemorySaver(), + # Run research stage + print_stage_header("Research Stage") + + # Initialize research model with potential overrides + research_provider = args.research_provider or args.provider + research_model_name = args.research_model or args.model + research_model = initialize_llm( + research_provider, research_model_name, temperature=args.temperature ) - # Run chat agent and exit - run_agent_with_retry( - chat_agent, - CHAT_PROMPT.format( - initial_request=initial_request, - web_research_section=( - WEB_RESEARCH_PROMPT_SECTION_CHAT if web_research_enabled else "" - ), - working_directory=working_directory, - current_date=current_date, - project_info=formatted_project_info, - ), - config, - ) - return - - # Validate message is provided - if not args.message: - print_error("--message is required") - sys.exit(1) - - base_task = args.message - config = { - "configurable": {"thread_id": str(uuid.uuid4())}, - "recursion_limit": args.recursion_limit, - "research_only": args.research_only, - "cowboy_mode": args.cowboy_mode, - "web_research_enabled": web_research_enabled, - "aider_config": args.aider_config, - "limit_tokens": args.disable_limit_tokens, - "auto_test": args.auto_test, - "test_cmd": args.test_cmd, - "max_test_cmd_retries": args.max_test_cmd_retries, - "experimental_fallback_handler": args.experimental_fallback_handler, - "test_cmd_timeout": args.test_cmd_timeout, - } - - # Store config in global memory for access by is_informational_query - _global_memory["config"] = config - - # Store base provider/model configuration - _global_memory["config"]["provider"] = args.provider - _global_memory["config"]["model"] = args.model - - # Store expert provider/model (no fallback) - _global_memory["config"]["expert_provider"] = args.expert_provider - _global_memory["config"]["expert_model"] = args.expert_model - - # Store planner config with fallback to base values - _global_memory["config"]["planner_provider"] = ( - args.planner_provider or args.provider - ) - _global_memory["config"]["planner_model"] = args.planner_model or args.model - - # Store research config with fallback to base values - _global_memory["config"]["research_provider"] = ( - args.research_provider or args.provider - ) - _global_memory["config"]["research_model"] = args.research_model or args.model - - # Store temperature in global config - _global_memory["config"]["temperature"] = args.temperature - - # Run research stage - print_stage_header("Research Stage") - - # Initialize research model with potential overrides - research_provider = args.research_provider or args.provider - research_model_name = args.research_model or args.model - research_model = initialize_llm( - research_provider, research_model_name, temperature=args.temperature - ) - - run_research_agent( - base_task, - research_model, - expert_enabled=expert_enabled, - research_only=args.research_only, - hil=args.hil, - memory=research_memory, - config=config, - ) - - # Proceed with planning and implementation if not an informational query - if not is_informational_query(): - # Initialize planning model with potential overrides - planner_provider = args.planner_provider or args.provider - planner_model_name = args.planner_model or args.model - planning_model = initialize_llm( - planner_provider, planner_model_name, temperature=args.temperature - ) - - # Run planning agent - run_planning_agent( + run_research_agent( base_task, - planning_model, + research_model, expert_enabled=expert_enabled, + research_only=args.research_only, hil=args.hil, - memory=planning_memory, + memory=research_memory, config=config, ) + # Proceed with planning and implementation if not an informational query + if not is_informational_query(): + # Initialize planning model with potential overrides + planner_provider = args.planner_provider or args.provider + planner_model_name = args.planner_model or args.model + planning_model = initialize_llm( + planner_provider, planner_model_name, temperature=args.temperature + ) + + # Run planning agent + run_planning_agent( + base_task, + planning_model, + expert_enabled=expert_enabled, + hil=args.hil, + memory=planning_memory, + config=config, + ) + except (KeyboardInterrupt, AgentInterrupt): print() print(" 👋 Bye!") diff --git a/ra_aid/database/__init__.py b/ra_aid/database/__init__.py new file mode 100644 index 0000000..c4f9631 --- /dev/null +++ b/ra_aid/database/__init__.py @@ -0,0 +1,26 @@ +""" +Database package for ra_aid. + +This package provides database functionality for the ra_aid application, +including connection management, models, and utility functions. +""" + +from ra_aid.database.connection import ( + init_db, + get_db, + close_db, + DatabaseManager +) +from ra_aid.database.models import BaseModel +from ra_aid.database.utils import get_model_count, truncate_table, ensure_tables_created + +__all__ = [ + 'init_db', + 'get_db', + 'close_db', + 'DatabaseManager', + 'BaseModel', + 'get_model_count', + 'truncate_table', + 'ensure_tables_created', +] diff --git a/ra_aid/database/connection.py b/ra_aid/database/connection.py new file mode 100644 index 0000000..bfba7d0 --- /dev/null +++ b/ra_aid/database/connection.py @@ -0,0 +1,371 @@ +""" +Database connection management for ra_aid. + +This module provides functions to initialize, get, and close database connections. +It also provides a context manager for database connections. +""" + +import os +import contextvars +from pathlib import Path +from typing import Optional, Any + +import peewee + +from ra_aid.logging_config import get_logger + +# Create contextvar to hold the database connection +db_var = contextvars.ContextVar("db", default=None) +logger = get_logger(__name__) + + +class DatabaseManager: + """ + Context manager for database connections. + + This class provides a context manager interface for database connections, + using the existing contextvars approach internally. + + Example: + with DatabaseManager() as db: + # Use the database connection + db.execute_sql("SELECT * FROM table") + + # Or with in-memory database: + with DatabaseManager(in_memory=True) as db: + # Use in-memory database + """ + + def __init__(self, in_memory: bool = False): + """ + Initialize the DatabaseManager. + + Args: + in_memory: Whether to use an in-memory database (default: False) + """ + self.in_memory = in_memory + + def __enter__(self) -> peewee.SqliteDatabase: + """ + Initialize the database connection and return it. + + Returns: + peewee.SqliteDatabase: The initialized database connection + """ + return init_db(in_memory=self.in_memory) + + def __exit__(self, exc_type: Optional[type], exc_val: Optional[Exception], + exc_tb: Optional[Any]) -> None: + """ + Close the database connection when exiting the context. + + Args: + exc_type: The exception type if an exception was raised + exc_val: The exception value if an exception was raised + exc_tb: The traceback if an exception was raised + """ + close_db() + + # Don't suppress exceptions + return False + +def init_db(in_memory: bool = False) -> peewee.SqliteDatabase: + """ + Initialize the database connection. + + Creates the .ra-aid directory if it doesn't exist and initializes + the SQLite database connection. If a database connection already exists, + returns the existing connection instead of creating a new one. + + Args: + in_memory: Whether to use an in-memory database (default: False) + + Returns: + peewee.SqliteDatabase: The initialized database connection + """ + # Check if a database connection already exists + existing_db = db_var.get() + if existing_db is not None: + # If the connection exists but is closed, reopen it + if existing_db.is_closed(): + try: + existing_db.connect() + except peewee.OperationalError as e: + logger.error(f"Failed to reopen existing database connection: {str(e)}") + # Continue to create a new connection if reopening fails + else: + return existing_db + else: + # Connection exists and is open, return it + return existing_db + + # Set up database path + if in_memory: + # Use in-memory database + db_path = ":memory:" + logger.debug("Using in-memory SQLite database") + else: + # Get current working directory and create .ra-aid directory if it doesn't exist + cwd = os.getcwd() + logger.debug(f"Current working directory: {cwd}") + print(f"DIRECT DEBUG: Current working directory: {cwd}") + + # Define the .ra-aid directory path + ra_aid_dir_str = os.path.join(cwd, ".ra-aid") + ra_aid_dir = Path(ra_aid_dir_str) + ra_aid_dir = ra_aid_dir.absolute() # Ensure we have the absolute path + ra_aid_dir_str = str(ra_aid_dir) # Update string representation with absolute path + + logger.debug(f"Creating database directory at: {ra_aid_dir_str}") + print(f"DIRECT DEBUG: Creating database directory at: {ra_aid_dir_str}") + + # Multiple approaches to ensure directory creation + directory_created = False + error_messages = [] + + # Approach 1: Try os.mkdir directly + if not os.path.exists(ra_aid_dir_str): + try: + logger.debug("Attempting directory creation with os.mkdir") + print(f"DIRECT DEBUG: Attempting directory creation with os.mkdir: {ra_aid_dir_str}") + os.mkdir(ra_aid_dir_str, mode=0o755) + directory_created = os.path.exists(ra_aid_dir_str) and os.path.isdir(ra_aid_dir_str) + if directory_created: + logger.debug("Directory created successfully with os.mkdir") + print(f"DIRECT DEBUG: Directory created successfully with os.mkdir") + except Exception as e: + error_msg = f"os.mkdir failed: {str(e)}" + logger.debug(error_msg) + print(f"DIRECT DEBUG: {error_msg}") + error_messages.append(error_msg) + else: + logger.debug("Directory already exists, skipping creation") + print(f"DIRECT DEBUG: Directory already exists at {ra_aid_dir_str}") + directory_created = True + + # Approach 2: Try os.makedirs if os.mkdir failed + if not directory_created: + try: + logger.debug("Attempting directory creation with os.makedirs") + print(f"DIRECT DEBUG: Attempting directory creation with os.makedirs: {ra_aid_dir_str}") + os.makedirs(ra_aid_dir_str, exist_ok=True, mode=0o755) + directory_created = os.path.exists(ra_aid_dir_str) and os.path.isdir(ra_aid_dir_str) + if directory_created: + logger.debug("Directory created successfully with os.makedirs") + print(f"DIRECT DEBUG: Directory created successfully with os.makedirs") + except Exception as e: + error_msg = f"os.makedirs failed: {str(e)}" + logger.debug(error_msg) + print(f"DIRECT DEBUG: {error_msg}") + error_messages.append(error_msg) + + # Approach 3: Try Path.mkdir if previous methods failed + if not directory_created: + try: + logger.debug("Attempting directory creation with Path.mkdir") + print(f"DIRECT DEBUG: Attempting directory creation with Path.mkdir: {ra_aid_dir}") + ra_aid_dir.mkdir(mode=0o755, parents=True, exist_ok=True) + directory_created = os.path.exists(ra_aid_dir_str) and os.path.isdir(ra_aid_dir_str) + if directory_created: + logger.debug("Directory created successfully with Path.mkdir") + print(f"DIRECT DEBUG: Directory created successfully with Path.mkdir") + except Exception as e: + error_msg = f"Path.mkdir failed: {str(e)}" + logger.debug(error_msg) + print(f"DIRECT DEBUG: {error_msg}") + error_messages.append(error_msg) + + # Verify the directory was actually created + path_exists = ra_aid_dir.exists() + os_exists = os.path.exists(ra_aid_dir_str) + is_dir = os.path.isdir(ra_aid_dir_str) if os_exists else False + + logger.debug(f"Directory verification: Path.exists={path_exists}, os.path.exists={os_exists}, os.path.isdir={is_dir}") + print(f"DIRECT DEBUG: Directory verification: Path.exists={path_exists}, os.path.exists={os_exists}, os.path.isdir={is_dir}") + + # Check parent directory permissions and contents for debugging + try: + parent_dir = os.path.dirname(ra_aid_dir_str) + parent_perms = oct(os.stat(parent_dir).st_mode)[-3:] + parent_contents = os.listdir(parent_dir) + logger.debug(f"Parent directory {parent_dir} permissions: {parent_perms}") + logger.debug(f"Parent directory contents: {parent_contents}") + print(f"DIRECT DEBUG: Parent directory {parent_dir} permissions: {parent_perms}") + print(f"DIRECT DEBUG: Parent directory contents: {parent_contents}") + except Exception as e: + logger.debug(f"Could not check parent directory: {str(e)}") + print(f"DIRECT DEBUG: Could not check parent directory: {str(e)}") + + if not os_exists or not is_dir: + error_msg = f"Directory does not exist or is not a directory after creation attempts: {ra_aid_dir_str}" + logger.error(error_msg) + print(f"DIRECT DEBUG ERROR: {error_msg}") + if error_messages: + logger.error(f"Previous errors: {', '.join(error_messages)}") + print(f"DIRECT DEBUG ERROR: Previous errors: {', '.join(error_messages)}") + raise FileNotFoundError(f"Failed to create directory: {ra_aid_dir_str}") + + # Check directory permissions + try: + permissions = oct(os.stat(ra_aid_dir_str).st_mode)[-3:] + logger.debug(f"Directory created/verified: {ra_aid_dir_str} with permissions {permissions}") + print(f"DIRECT DEBUG: Directory created/verified: {ra_aid_dir_str} with permissions {permissions}") + + # List directory contents for debugging + dir_contents = os.listdir(ra_aid_dir_str) + logger.debug(f"Directory contents: {dir_contents}") + print(f"DIRECT DEBUG: Directory contents: {dir_contents}") + except Exception as e: + logger.debug(f"Could not check directory details: {str(e)}") + print(f"DIRECT DEBUG: Could not check directory details: {str(e)}") + + # Database path for file-based database - use os.path.join for maximum compatibility + db_path = os.path.join(ra_aid_dir_str, "pk.db") + logger.debug(f"Database path: {db_path}") + print(f"DIRECT DEBUG: Database path: {db_path}") + + try: + # For file-based databases, ensure the file exists or can be created + if db_path != ":memory:": + # Check if the database file exists + db_file_exists = os.path.exists(db_path) + logger.debug(f"Database file exists check: {db_file_exists}") + print(f"DIRECT DEBUG: Database file exists check: {db_file_exists}") + + # If the file doesn't exist, try to create an empty file to ensure we have write permissions + if not db_file_exists: + try: + logger.debug(f"Creating empty database file at: {db_path}") + print(f"DIRECT DEBUG: Creating empty database file at: {db_path}") + with open(db_path, 'w') as f: + pass # Create empty file + + # Verify the file was created + if os.path.exists(db_path): + logger.debug("Empty database file created successfully") + print(f"DIRECT DEBUG: Empty database file created successfully") + else: + logger.error(f"Failed to create database file at: {db_path}") + print(f"DIRECT DEBUG ERROR: Failed to create database file at: {db_path}") + except Exception as e: + logger.error(f"Error creating database file: {str(e)}") + print(f"DIRECT DEBUG ERROR: Error creating database file: {str(e)}") + # Continue anyway, as SQLite might be able to create the file itself + + # Initialize the database connection + logger.debug(f"Initializing SQLite database at: {db_path}") + print(f"DIRECT DEBUG: Initializing SQLite database at: {db_path}") + db = peewee.SqliteDatabase( + db_path, + pragmas={ + 'journal_mode': 'wal', # Write-Ahead Logging for better concurrency + 'foreign_keys': 1, # Enforce foreign key constraints + 'cache_size': -1024 * 32, # 32MB cache + } + ) + + # Always explicitly connect to ensure the connection is established + if db.is_closed(): + logger.debug("Explicitly connecting to database") + print(f"DIRECT DEBUG: Explicitly connecting to database") + db.connect() + + # Store the database connection in the contextvar + db_var.set(db) + + # Store whether this is an in-memory database (for backward compatibility) + db._is_in_memory = in_memory + + # Verify the database is usable by executing a simple query + if not in_memory: + try: + db.execute_sql("SELECT 1") + logger.debug("Database connection verified with test query") + print(f"DIRECT DEBUG: Database connection verified with test query") + + # Check if the database file exists after initialization + db_file_exists = os.path.exists(db_path) + db_file_size = os.path.getsize(db_path) if db_file_exists else 0 + logger.debug(f"Database file check after init: exists={db_file_exists}, size={db_file_size} bytes") + print(f"DIRECT DEBUG: Database file check after init: exists={db_file_exists}, size={db_file_size} bytes") + except Exception as e: + logger.error(f"Database verification failed: {str(e)}") + print(f"DIRECT DEBUG ERROR: Database verification failed: {str(e)}") + # Continue anyway, as this is just a verification step + + # Only show initialization message if it hasn't been shown before + if not hasattr(db, '_message_shown') or not db._message_shown: + if in_memory: + logger.debug("In-memory database connection initialized successfully") + else: + logger.debug("Database connection initialized successfully") + db._message_shown = True + + return db + except peewee.OperationalError as e: + logger.error(f"Database Operational Error: {str(e)}") + raise + except peewee.DatabaseError as e: + logger.error(f"Database Error: {str(e)}") + raise + except Exception as e: + logger.error(f"Failed to initialize database: {str(e)}") + raise + +def get_db() -> peewee.SqliteDatabase: + """ + Get the current database connection. + + If no connection exists, initializes a new one. + If connection exists but is closed, reopens it. + + Returns: + peewee.SqliteDatabase: The current database connection + """ + db = db_var.get() + + if db is None: + # No database connection exists, initialize one + # Use the default in-memory mode (False) + return init_db(in_memory=False) + + # Check if connection is closed and reopen if needed + if db.is_closed(): + try: + logger.debug("Attempting to reopen closed database connection") + db.connect() + logger.info("Reopened existing database connection") + except peewee.OperationalError as e: + logger.error(f"Failed to reopen database connection: {str(e)}") + # Create a completely new connection + # First, remove the old connection from the context var + db_var.set(None) + # Then initialize a new connection with the same in-memory setting + in_memory = hasattr(db, '_is_in_memory') and db._is_in_memory + logger.debug(f"Creating new database connection (in_memory={in_memory})") + # Create a completely new database object, don't reuse the old one + return init_db(in_memory=in_memory) + + return db + +def close_db() -> None: + """ + Close the current database connection if it exists. + + Handles various error conditions gracefully. + """ + db = db_var.get() + if db is None: + logger.warning("No database connection to close") + return + + try: + if not db.is_closed(): + db.close() + logger.info("Database connection closed successfully") + else: + logger.warning("Database connection was already closed") + except peewee.DatabaseError as e: + logger.error(f"Database Error: Failed to close connection: {str(e)}") + except Exception as e: + logger.error(f"Failed to close database connection: {str(e)}") diff --git a/ra_aid/database/models.py b/ra_aid/database/models.py new file mode 100644 index 0000000..13ee099 --- /dev/null +++ b/ra_aid/database/models.py @@ -0,0 +1,62 @@ +""" +Database models for ra_aid. + +This module defines the base model class that all models will inherit from. +""" + +import datetime +from typing import Any, Dict, Type, TypeVar + +import peewee + +from ra_aid.database.connection import get_db +from ra_aid.logging_config import get_logger + +T = TypeVar('T', bound='BaseModel') +logger = get_logger(__name__) + +class BaseModel(peewee.Model): + """ + Base model class for all ra_aid models. + + All models should inherit from this class to ensure consistent + behavior and database connection. + """ + created_at = peewee.DateTimeField(default=datetime.datetime.now) + updated_at = peewee.DateTimeField(default=datetime.datetime.now) + + class Meta: + database = get_db() + + def save(self, *args: Any, **kwargs: Any) -> int: + """ + Override save to update the updated_at field. + + Args: + *args: Arguments to pass to the parent save method + **kwargs: Keyword arguments to pass to the parent save method + + Returns: + int: The primary key of the saved instance + """ + self.updated_at = datetime.datetime.now() + return super().save(*args, **kwargs) + + @classmethod + def get_or_create(cls: Type[T], **kwargs: Any) -> tuple[T, bool]: + """ + Get an instance or create it if it doesn't exist. + + Args: + **kwargs: Fields to use for lookup and creation + + Returns: + tuple: (instance, created) where created is a boolean indicating + whether a new instance was created + """ + try: + return super().get_or_create(**kwargs) + except peewee.DatabaseError as e: + # Log the error with logger + logger.error(f"Failed in get_or_create: {str(e)}") + raise diff --git a/ra_aid/database/tests/ra_aid/database/test_connection.py b/ra_aid/database/tests/ra_aid/database/test_connection.py new file mode 100644 index 0000000..904006e --- /dev/null +++ b/ra_aid/database/tests/ra_aid/database/test_connection.py @@ -0,0 +1,653 @@ +""" +Tests for the database connection module. +""" + +import os +import pytest +from pathlib import Path +from unittest.mock import patch, MagicMock + +import peewee + +from ra_aid.database.connection import ( + init_db, get_db, close_db, db_var, DatabaseManager +) + + +@pytest.fixture +def cleanup_db(): + """ + Fixture to clean up database connections after tests. + """ + # Run the test + yield + + # Clean up after the test + db = db_var.get() + if db is not None: + if hasattr(db, '_is_in_memory'): + delattr(db, '_is_in_memory') + if hasattr(db, '_message_shown'): + delattr(db, '_message_shown') + if not db.is_closed(): + db.close() + + # Reset the contextvar + db_var.set(None) + + +@pytest.fixture +def setup_in_memory_db(): + """ + Fixture to set up an in-memory database for testing. + """ + # Initialize in-memory database + db = init_db(in_memory=True) + + # Run the test + yield db + + # Clean up + if not db.is_closed(): + db.close() + db_var.set(None) + + +def test_init_db_creates_directory(cleanup_db, tmp_path): + """ + Test that init_db creates the .ra-aid directory if it doesn't exist. + """ + # Get and print the original working directory + original_cwd = os.getcwd() + print(f"Original working directory: {original_cwd}") + + # Convert tmp_path to string for consistent handling + tmp_path_str = str(tmp_path.absolute()) + print(f"Temporary directory path: {tmp_path_str}") + + # Change to the temporary directory + os.chdir(tmp_path_str) + current_cwd = os.getcwd() + print(f"Changed working directory to: {current_cwd}") + assert current_cwd == tmp_path_str, f"Failed to change directory: {current_cwd} != {tmp_path_str}" + + # Create the .ra-aid directory manually to ensure it exists + ra_aid_path_str = os.path.join(current_cwd, ".ra-aid") + print(f"Creating .ra-aid directory at: {ra_aid_path_str}") + os.makedirs(ra_aid_path_str, exist_ok=True) + + # Verify the directory was created + assert os.path.exists(ra_aid_path_str), f".ra-aid directory not found at {ra_aid_path_str}" + assert os.path.isdir(ra_aid_path_str), f"{ra_aid_path_str} exists but is not a directory" + + # Create a test file to verify write permissions + test_file_path = os.path.join(ra_aid_path_str, "test_write.txt") + print(f"Creating test file to verify write permissions: {test_file_path}") + with open(test_file_path, 'w') as f: + f.write("Test write permissions") + + # Verify the test file was created + assert os.path.exists(test_file_path), f"Test file not created at {test_file_path}" + + # Create an empty database file to ensure it exists before init_db + db_file_str = os.path.join(ra_aid_path_str, "pk.db") + print(f"Creating empty database file at: {db_file_str}") + with open(db_file_str, 'w') as f: + f.write("") # Create empty file + + # Verify the database file was created + assert os.path.exists(db_file_str), f"Empty database file not created at {db_file_str}" + print(f"Empty database file size: {os.path.getsize(db_file_str)} bytes") + + # Get directory permissions for debugging + dir_perms = oct(os.stat(ra_aid_path_str).st_mode)[-3:] + print(f"Directory permissions: {dir_perms}") + + # Initialize the database + print("Calling init_db()") + db = init_db() + print("init_db() returned successfully") + + # List contents of the current directory for debugging + print(f"Contents of current directory: {os.listdir(current_cwd)}") + + # List contents of the .ra-aid directory for debugging + print(f"Contents of .ra-aid directory: {os.listdir(ra_aid_path_str)}") + + # Check that the database file exists using os.path + assert os.path.exists(db_file_str), f"Database file not found at {db_file_str}" + assert os.path.isfile(db_file_str), f"{db_file_str} exists but is not a file" + print(f"Final database file size: {os.path.getsize(db_file_str)} bytes") + + +def test_init_db_creates_database_file(cleanup_db, tmp_path): + """ + Test that init_db creates the database file. + """ + # Change to the temporary directory + os.chdir(tmp_path) + + # Initialize the database + init_db() + + # Check that the database file was created + assert (tmp_path / ".ra-aid" / "pk.db").exists() + assert (tmp_path / ".ra-aid" / "pk.db").is_file() + + +def test_init_db_returns_database_connection(cleanup_db): + """ + Test that init_db returns a database connection. + """ + # Initialize the database + db = init_db() + + # Check that the database connection is returned + assert isinstance(db, peewee.SqliteDatabase) + assert not db.is_closed() + + +def test_init_db_with_in_memory_mode(cleanup_db): + """ + Test that init_db with in_memory=True creates an in-memory database. + """ + # Initialize the database in in-memory mode + db = init_db(in_memory=True) + + # Check that the database connection is returned + assert isinstance(db, peewee.SqliteDatabase) + assert not db.is_closed() + assert hasattr(db, '_is_in_memory') + assert db._is_in_memory is True + + +def test_in_memory_mode_no_directory_created(cleanup_db, tmp_path): + """ + Test that when using in-memory mode, no directory is created. + """ + # Change to the temporary directory + os.chdir(tmp_path) + + # Initialize the database in in-memory mode + init_db(in_memory=True) + + # Check that the .ra-aid directory was not created + # (Note: it might be created by other tests, so we can't assert it doesn't exist) + # Instead, check that the database file was not created + assert not (tmp_path / ".ra-aid" / "pk.db").exists() + + +def test_init_db_returns_existing_connection(cleanup_db): + """ + Test that init_db returns the existing connection if one exists. + """ + # Initialize the database + db1 = init_db() + + # Initialize the database again + db2 = init_db() + + # Check that the same connection is returned + assert db1 is db2 + + +def test_init_db_reopens_closed_connection(cleanup_db): + """ + Test that init_db reopens a closed connection. + """ + # Initialize the database + db1 = init_db() + + # Close the connection + db1.close() + + # Initialize the database again + db2 = init_db() + + # Check that the same connection is returned and it's open + assert db1 is db2 + assert not db1.is_closed() + + +def test_get_db_initializes_connection(cleanup_db): + """ + Test that get_db initializes a connection if none exists. + """ + # Get the database connection + db = get_db() + + # Check that a connection was initialized + assert isinstance(db, peewee.SqliteDatabase) + assert not db.is_closed() + + +def test_get_db_returns_existing_connection(cleanup_db): + """ + Test that get_db returns the existing connection if one exists. + """ + # Initialize the database + db1 = init_db() + + # Get the database connection + db2 = get_db() + + # Check that the same connection is returned + assert db1 is db2 + + +def test_get_db_reopens_closed_connection(cleanup_db): + """ + Test that get_db reopens a closed connection. + """ + # Initialize the database + db = init_db() + + # Close the connection + db.close() + + # Get the database connection + db2 = get_db() + + # Check that the same connection is returned and it's open + assert db is db2 + assert not db.is_closed() + + +def test_get_db_handles_reopen_error(cleanup_db, monkeypatch): + """ + Test that get_db handles errors when reopening a connection. + """ + # Initialize the database + db = init_db() + + # Close the connection + db.close() + + # Create a patched version of the connect method that raises an error + original_connect = peewee.SqliteDatabase.connect + + def mock_connect(self, *args, **kwargs): + if self is db: # Only raise for the specific db instance + raise peewee.OperationalError("Test error") + return original_connect(self, *args, **kwargs) + + # Apply the patch + monkeypatch.setattr(peewee.SqliteDatabase, 'connect', mock_connect) + + # Get the database connection + db2 = get_db() + + # Check that a new connection was initialized + assert db is not db2 + assert not db2.is_closed() + + +def test_close_db_closes_connection(cleanup_db): + """ + Test that close_db closes the connection. + """ + # Initialize the database + db = init_db() + + # Close the connection + close_db() + + # Check that the connection is closed + assert db.is_closed() + + +def test_close_db_handles_no_connection(): + """ + Test that close_db handles the case where no connection exists. + """ + # Reset the contextvar + db_var.set(None) + + # Close the connection (should not raise an error) + close_db() + + +def test_close_db_handles_already_closed_connection(cleanup_db): + """ + Test that close_db handles the case where the connection is already closed. + """ + # Initialize the database + db = init_db() + + # Close the connection + db.close() + + # Close the connection again (should not raise an error) + close_db() + + +@patch('ra_aid.database.connection.peewee.SqliteDatabase.close') +def test_close_db_handles_error(mock_close, cleanup_db): + """ + Test that close_db handles errors when closing the connection. + """ + # Initialize the database + init_db() + + # Make close raise an error + mock_close.side_effect = peewee.DatabaseError("Test error") + + # Close the connection (should not raise an error) + close_db() + + +def test_database_manager_context_manager(cleanup_db): + """ + Test that DatabaseManager works as a context manager. + """ + # Use the context manager + with DatabaseManager() as db: + # Check that a connection was initialized + assert isinstance(db, peewee.SqliteDatabase) + assert not db.is_closed() + + # Store the connection for later + db_in_context = db + + # Check that the connection is closed after exiting the context + assert db_in_context.is_closed() + + +def test_database_manager_with_in_memory_mode(cleanup_db): + """ + Test that DatabaseManager with in_memory=True creates an in-memory database. + """ + # Use the context manager with in_memory=True + with DatabaseManager(in_memory=True) as db: + # Check that a connection was initialized + assert isinstance(db, peewee.SqliteDatabase) + assert not db.is_closed() + assert hasattr(db, '_is_in_memory') + assert db._is_in_memory is True + + +def test_init_db_shows_message_only_once(cleanup_db, caplog): + """ + Test that init_db only shows the initialization message once. + """ + # Initialize the database + init_db(in_memory=True) + + # Clear the log + caplog.clear() + + # Initialize the database again + init_db(in_memory=True) + + # Check that no message was logged + assert "database connection initialized" not in caplog.text.lower() + + +def test_init_db_sets_is_in_memory_attribute(cleanup_db): + """ + Test that init_db sets the _is_in_memory attribute. + """ + # Initialize the database with in_memory=False + db = init_db(in_memory=False) + + # Check that the _is_in_memory attribute is set to False + assert hasattr(db, '_is_in_memory') + assert db._is_in_memory is False + + # Reset the contextvar + db_var.set(None) + + # Initialize the database with in_memory=True + db = init_db(in_memory=True) + + # Check that the _is_in_memory attribute is set to True + assert hasattr(db, '_is_in_memory') + assert db._is_in_memory is True +""" +Tests for the database connection module. +""" + +import os +import shutil +from pathlib import Path +import pytest +import peewee + +from ra_aid.database.connection import ( + init_db, get_db, close_db, + db_var, DatabaseManager, logger +) + + +@pytest.fixture +def cleanup_db(): + """ + Fixture to clean up database connections and files between tests. + + This fixture: + 1. Closes any open database connection + 2. Resets the contextvar + 3. Cleans up the .ra-aid directory + """ + # Store the current working directory + original_cwd = os.getcwd() + + # Run the test + yield + + # Clean up after the test + try: + # Close any open database connection + close_db() + + # Reset the contextvar + db_var.set(None) + + # Clean up the .ra-aid directory if it exists + ra_aid_dir = Path(os.getcwd()) / ".ra-aid" + ra_aid_dir_str = str(ra_aid_dir.absolute()) + + # Check using both methods + path_exists = ra_aid_dir.exists() + os_exists = os.path.exists(ra_aid_dir_str) + + print(f"Cleanup check: Path.exists={path_exists}, os.path.exists={os_exists}") + + if os_exists: + # Only remove the database file, not the entire directory + db_file = os.path.join(ra_aid_dir_str, "pk.db") + if os.path.exists(db_file): + os.unlink(db_file) + + # Remove WAL and SHM files if they exist + wal_file = os.path.join(ra_aid_dir_str, "pk.db-wal") + if os.path.exists(wal_file): + os.unlink(wal_file) + + shm_file = os.path.join(ra_aid_dir_str, "pk.db-shm") + if os.path.exists(shm_file): + os.unlink(shm_file) + + # List remaining contents for debugging + if os.path.exists(ra_aid_dir_str): + print(f"Directory contents after cleanup: {os.listdir(ra_aid_dir_str)}") + except Exception as e: + # Log but don't fail if cleanup has issues + print(f"Cleanup error (non-fatal): {str(e)}") + + # Make sure we're back in the original directory + os.chdir(original_cwd) + + +class TestInitDb: + """Tests for the init_db function.""" + + def test_init_db_default(self, cleanup_db): + """Test init_db with default parameters.""" + # Get the absolute path of the current working directory + cwd = os.getcwd() + print(f"Current working directory: {cwd}") + + # Initialize the database + db = init_db() + + assert isinstance(db, peewee.SqliteDatabase) + assert not db.is_closed() + assert hasattr(db, '_is_in_memory') + assert db._is_in_memory is False + + # Verify the database file was created using both Path and os.path methods + ra_aid_dir = Path(cwd) / ".ra-aid" + ra_aid_dir_str = str(ra_aid_dir.absolute()) + + # Check directory existence using both methods + path_exists = ra_aid_dir.exists() + os_exists = os.path.exists(ra_aid_dir_str) + print(f"Directory check: Path.exists={path_exists}, os.path.exists={os_exists}") + + # List the contents of the current directory + print(f"Contents of {cwd}: {os.listdir(cwd)}") + + # If the directory exists, list its contents + if os_exists: + print(f"Contents of {ra_aid_dir_str}: {os.listdir(ra_aid_dir_str)}") + + # Use os.path for assertions to be more reliable + assert os.path.exists(ra_aid_dir_str), f"Directory {ra_aid_dir_str} does not exist" + assert os.path.isdir(ra_aid_dir_str), f"{ra_aid_dir_str} is not a directory" + + db_file = os.path.join(ra_aid_dir_str, "pk.db") + assert os.path.exists(db_file), f"Database file {db_file} does not exist" + assert os.path.isfile(db_file), f"{db_file} is not a file" + + def test_init_db_in_memory(self, cleanup_db): + """Test init_db with in_memory=True.""" + db = init_db(in_memory=True) + + assert isinstance(db, peewee.SqliteDatabase) + assert not db.is_closed() + assert hasattr(db, '_is_in_memory') + assert db._is_in_memory is True + + def test_init_db_reuses_connection(self, cleanup_db): + """Test that init_db reuses an existing connection.""" + db1 = init_db() + db2 = init_db() + + assert db1 is db2 + + def test_init_db_reopens_closed_connection(self, cleanup_db): + """Test that init_db reopens a closed connection.""" + db1 = init_db() + db1.close() + assert db1.is_closed() + + db2 = init_db() + assert db1 is db2 + assert not db1.is_closed() + + +class TestGetDb: + """Tests for the get_db function.""" + + def test_get_db_creates_connection(self, cleanup_db): + """Test that get_db creates a new connection if none exists.""" + # Reset the contextvar to ensure no connection exists + db_var.set(None) + + db = get_db() + + assert isinstance(db, peewee.SqliteDatabase) + assert not db.is_closed() + assert hasattr(db, '_is_in_memory') + assert db._is_in_memory is False + + def test_get_db_reuses_connection(self, cleanup_db): + """Test that get_db reuses an existing connection.""" + db1 = init_db() + db2 = get_db() + + assert db1 is db2 + + def test_get_db_reopens_closed_connection(self, cleanup_db): + """Test that get_db reopens a closed connection.""" + db1 = init_db() + db1.close() + assert db1.is_closed() + + db2 = get_db() + assert db1 is db2 + assert not db1.is_closed() + + +class TestCloseDb: + """Tests for the close_db function.""" + + def test_close_db(self, cleanup_db): + """Test that close_db closes an open connection.""" + db = init_db() + assert not db.is_closed() + + close_db() + assert db.is_closed() + + def test_close_db_no_connection(self, cleanup_db): + """Test that close_db handles the case where no connection exists.""" + # Reset the contextvar to ensure no connection exists + db_var.set(None) + + # This should not raise an exception + close_db() + + def test_close_db_already_closed(self, cleanup_db): + """Test that close_db handles the case where the connection is already closed.""" + db = init_db() + db.close() + assert db.is_closed() + + # This should not raise an exception + close_db() + + +class TestDatabaseManager: + """Tests for the DatabaseManager class.""" + + def test_database_manager_default(self, cleanup_db): + """Test DatabaseManager with default parameters.""" + with DatabaseManager() as db: + assert isinstance(db, peewee.SqliteDatabase) + assert not db.is_closed() + assert hasattr(db, '_is_in_memory') + assert db._is_in_memory is False + + # Verify the database file was created + ra_aid_dir = Path(os.getcwd()) / ".ra-aid" + assert ra_aid_dir.exists() + assert (ra_aid_dir / "pk.db").exists() + + # Verify the connection is closed after exiting the context + assert db.is_closed() + + def test_database_manager_in_memory(self, cleanup_db): + """Test DatabaseManager with in_memory=True.""" + with DatabaseManager(in_memory=True) as db: + assert isinstance(db, peewee.SqliteDatabase) + assert not db.is_closed() + assert hasattr(db, '_is_in_memory') + assert db._is_in_memory is True + + # Verify the connection is closed after exiting the context + assert db.is_closed() + + def test_database_manager_exception_handling(self, cleanup_db): + """Test that DatabaseManager properly handles exceptions.""" + try: + with DatabaseManager() as db: + assert not db.is_closed() + raise ValueError("Test exception") + except ValueError: + # The exception should be propagated + pass + + # Verify the connection is closed even if an exception occurred + assert db.is_closed() diff --git a/ra_aid/database/utils.py b/ra_aid/database/utils.py new file mode 100644 index 0000000..a77157e --- /dev/null +++ b/ra_aid/database/utils.py @@ -0,0 +1,94 @@ +""" +Database utility functions for ra_aid. + +This module provides utility functions for common database operations. +""" + +import importlib +import inspect +from typing import List, Type + +import peewee + +from ra_aid.database.connection import get_db +from ra_aid.database.models import BaseModel +from ra_aid.logging_config import get_logger + +logger = get_logger(__name__) + +def ensure_tables_created(models: List[Type[BaseModel]] = None) -> None: + """ + Ensure that database tables for the specified models exist. + + If no models are specified, this function will attempt to discover + all models that inherit from BaseModel. + + Args: + models: Optional list of model classes to create tables for + """ + db = get_db() + + if models is None: + # If no models are specified, try to discover them + models = [] + try: + # Import all modules that might contain models + # This is a placeholder - in a real implementation, you would + # dynamically discover and import all modules that might contain models + from ra_aid.database import models as models_module + + # Find all classes in the module that inherit from BaseModel + for name, obj in inspect.getmembers(models_module): + if (inspect.isclass(obj) and + issubclass(obj, BaseModel) and + obj != BaseModel): + models.append(obj) + except ImportError as e: + logger.warning(f"Error importing model modules: {e}") + + if not models: + logger.warning("No models found to create tables for") + return + + try: + with db.atomic(): + db.create_tables(models, safe=True) + logger.info(f"Successfully created tables for {len(models)} models") + except peewee.DatabaseError as e: + logger.error(f"Database Error: Failed to create tables: {str(e)}") + raise + except Exception as e: + logger.error(f"Error: Failed to create tables: {str(e)}") + raise + +def get_model_count(model_class: Type[BaseModel]) -> int: + """ + Get the count of records for a specific model. + + Args: + model_class: The model class to count records for + + Returns: + int: The number of records for the model + """ + try: + return model_class.select().count() + except peewee.DatabaseError as e: + logger.error(f"Database Error: Failed to count records: {str(e)}") + return 0 + +def truncate_table(model_class: Type[BaseModel]) -> None: + """ + Delete all records from a model's table. + + Args: + model_class: The model class to truncate + """ + db = get_db() + try: + with db.atomic(): + model_class.delete().execute() + logger.info(f"Successfully truncated table for {model_class.__name__}") + except peewee.DatabaseError as e: + logger.error(f"Database Error: Failed to truncate table: {str(e)}") + raise diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..06d3914 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1 @@ +# This file is intentionally left empty to make the directory a Python package diff --git a/tests/ra_aid/__init__.py b/tests/ra_aid/__init__.py new file mode 100644 index 0000000..06d3914 --- /dev/null +++ b/tests/ra_aid/__init__.py @@ -0,0 +1 @@ +# This file is intentionally left empty to make the directory a Python package diff --git a/tests/ra_aid/database/__init__.py b/tests/ra_aid/database/__init__.py new file mode 100644 index 0000000..06d3914 --- /dev/null +++ b/tests/ra_aid/database/__init__.py @@ -0,0 +1 @@ +# This file is intentionally left empty to make the directory a Python package diff --git a/tests/ra_aid/database/test_connection.py b/tests/ra_aid/database/test_connection.py new file mode 100644 index 0000000..0ee4c58 --- /dev/null +++ b/tests/ra_aid/database/test_connection.py @@ -0,0 +1,181 @@ +""" +Tests for the database connection module. +""" +import os +import shutil +from pathlib import Path +import pytest +import peewee +from unittest.mock import patch, MagicMock + +from ra_aid.database.connection import ( + init_db, get_db, close_db, + db_var, DatabaseManager, logger +) + +@pytest.fixture +def cleanup_db(): + """ + Fixture to clean up database connections and files between tests. + This fixture: + 1. Closes any open database connection + 2. Resets the contextvar + 3. Cleans up the .ra-aid directory + """ + # Run the test + yield + # Clean up after the test + try: + # Close any open database connection + close_db() + # Reset the contextvar + db_var.set(None) + # Clean up the .ra-aid directory if it exists + ra_aid_dir = Path(os.getcwd()) / ".ra-aid" + if ra_aid_dir.exists(): + # Only remove the database file, not the entire directory + db_file = ra_aid_dir / "pk.db" + if db_file.exists(): + db_file.unlink() + # Remove WAL and SHM files if they exist + wal_file = ra_aid_dir / "pk.db-wal" + if wal_file.exists(): + wal_file.unlink() + shm_file = ra_aid_dir / "pk.db-shm" + if shm_file.exists(): + shm_file.unlink() + except Exception as e: + # Log but don't fail if cleanup has issues + print(f"Cleanup error (non-fatal): {str(e)}") + +@pytest.fixture +def mock_logger(): + """Mock the logger to test for output messages.""" + with patch('ra_aid.database.connection.logger') as mock: + yield mock + +class TestInitDb: + """Tests for the init_db function.""" + def test_init_db_default(self, cleanup_db): + """Test init_db with default parameters.""" + db = init_db() + assert isinstance(db, peewee.SqliteDatabase) + assert not db.is_closed() + assert hasattr(db, '_is_in_memory') + assert db._is_in_memory is False + # Verify the database file was created + ra_aid_dir = Path(os.getcwd()) / ".ra-aid" + assert ra_aid_dir.exists() + assert (ra_aid_dir / "pk.db").exists() + + def test_init_db_in_memory(self, cleanup_db): + """Test init_db with in_memory=True.""" + db = init_db(in_memory=True) + assert isinstance(db, peewee.SqliteDatabase) + assert not db.is_closed() + assert hasattr(db, '_is_in_memory') + assert db._is_in_memory is True + + def test_init_db_reuses_connection(self, cleanup_db): + """Test that init_db reuses an existing connection.""" + db1 = init_db() + db2 = init_db() + assert db1 is db2 + + def test_init_db_reopens_closed_connection(self, cleanup_db): + """Test that init_db reopens a closed connection.""" + db1 = init_db() + db1.close() + assert db1.is_closed() + db2 = init_db() + assert db1 is db2 + assert not db1.is_closed() + +class TestGetDb: + """Tests for the get_db function.""" + def test_get_db_creates_connection(self, cleanup_db): + """Test that get_db creates a new connection if none exists.""" + # Reset the contextvar to ensure no connection exists + db_var.set(None) + db = get_db() + assert isinstance(db, peewee.SqliteDatabase) + assert not db.is_closed() + assert hasattr(db, '_is_in_memory') + assert db._is_in_memory is False + + def test_get_db_reuses_connection(self, cleanup_db): + """Test that get_db reuses an existing connection.""" + db1 = init_db() + db2 = get_db() + assert db1 is db2 + + def test_get_db_reopens_closed_connection(self, cleanup_db): + """Test that get_db reopens a closed connection.""" + db1 = init_db() + db1.close() + assert db1.is_closed() + db2 = get_db() + assert db1 is db2 + assert not db1.is_closed() + +class TestCloseDb: + """Tests for the close_db function.""" + def test_close_db(self, cleanup_db): + """Test that close_db closes an open connection.""" + db = init_db() + assert not db.is_closed() + close_db() + assert db.is_closed() + + def test_close_db_no_connection(self, cleanup_db): + """Test that close_db handles the case where no connection exists.""" + # Reset the contextvar to ensure no connection exists + db_var.set(None) + # This should not raise an exception + close_db() + + def test_close_db_already_closed(self, cleanup_db): + """Test that close_db handles the case where the connection is already closed.""" + db = init_db() + db.close() + assert db.is_closed() + # This should not raise an exception + close_db() + +class TestDatabaseManager: + """Tests for the DatabaseManager class.""" + def test_database_manager_default(self, cleanup_db): + """Test DatabaseManager with default parameters.""" + with DatabaseManager() as db: + assert isinstance(db, peewee.SqliteDatabase) + assert not db.is_closed() + assert hasattr(db, '_is_in_memory') + assert db._is_in_memory is False + # Verify the database file was created + ra_aid_dir = Path(os.getcwd()) / ".ra-aid" + assert ra_aid_dir.exists() + assert (ra_aid_dir / "pk.db").exists() + # Verify the connection is closed after exiting the context + assert db.is_closed() + + def test_database_manager_in_memory(self, cleanup_db): + """Test DatabaseManager with in_memory=True.""" + with DatabaseManager(in_memory=True) as db: + assert isinstance(db, peewee.SqliteDatabase) + assert not db.is_closed() + assert hasattr(db, '_is_in_memory') + assert db._is_in_memory is True + # Verify the connection is closed after exiting the context + assert db.is_closed() + + def test_database_manager_exception_handling(self, cleanup_db): + """Test that DatabaseManager properly handles exceptions.""" + try: + with DatabaseManager() as db: + assert not db.is_closed() + raise ValueError("Test exception") + except ValueError: + # The exception should be propagated + pass + # Verify the connection is closed even if an exception occurred + assert db.is_closed() diff --git a/tests/ra_aid/database/test_models.py b/tests/ra_aid/database/test_models.py new file mode 100644 index 0000000..c6ae76a --- /dev/null +++ b/tests/ra_aid/database/test_models.py @@ -0,0 +1,132 @@ +""" +Tests for the database models module. +""" + +from unittest.mock import patch + +import pytest +import peewee + +from ra_aid.database.models import BaseModel +from ra_aid.database.connection import ( + db_var, get_db, init_db, close_db +) + + +@pytest.fixture +def cleanup_db(): + """Reset the database contextvar and connection state after each test.""" + # Reset before the test + db = db_var.get() + if db is not None: + try: + if not db.is_closed(): + db.close() + except Exception: + # Ignore errors when closing the database + pass + db_var.set(None) + + # Run the test + yield + + # Reset after the test + db = db_var.get() + if db is not None: + try: + if not db.is_closed(): + db.close() + except Exception: + # Ignore errors when closing the database + pass + db_var.set(None) + + +@pytest.fixture +def setup_test_model(cleanup_db): + """Set up a test model class for testing.""" + # Initialize an in-memory database connection + db = init_db(in_memory=True) + + # Define a test model + class TestModel(BaseModel): + name = peewee.CharField() + value = peewee.IntegerField(null=True) + + class Meta: + database = db + + # Create the table + with db.atomic(): + db.create_tables([TestModel], safe=True) + + yield TestModel + + # Clean up + with db.atomic(): + TestModel.drop_table(safe=True) + + +def test_base_model_save_updates_timestamps(setup_test_model): + """Test that saving a model updates the timestamps.""" + TestModel = setup_test_model + + # Create a new instance + instance = TestModel(name="test", value=42) + instance.save() + + # Check that created_at and updated_at are set + assert instance.created_at is not None + assert instance.updated_at is not None + + # Store the original timestamps + original_created_at = instance.created_at + original_updated_at = instance.updated_at + + # Wait a moment to ensure timestamps would be different + import time + time.sleep(0.001) + + # Update the instance + instance.value = 43 + instance.save() + + # Check that updated_at changed but created_at didn't + assert instance.created_at == original_created_at + assert instance.updated_at != original_updated_at + + +def test_base_model_get_or_create(setup_test_model): + """Test the get_or_create method.""" + TestModel = setup_test_model + + # First call should create a new instance + instance, created = TestModel.get_or_create(name="test", value=42) + assert created is True + assert instance.name == "test" + assert instance.value == 42 + + # Second call with same parameters should return existing instance + instance2, created2 = TestModel.get_or_create(name="test", value=42) + assert created2 is False + assert instance2.id == instance.id + + # Call with different parameters should create a new instance + instance3, created3 = TestModel.get_or_create(name="test2", value=43) + assert created3 is True + assert instance3.id != instance.id + + +@patch('ra_aid.database.models.logger') +def test_base_model_get_or_create_handles_errors(mock_logger, setup_test_model): + """Test that get_or_create handles database errors properly.""" + TestModel = setup_test_model + + # Mock the parent get_or_create to raise a DatabaseError + with patch('peewee.Model.get_or_create', side_effect=peewee.DatabaseError("Test error")): + # Call should raise the error + with pytest.raises(peewee.DatabaseError): + TestModel.get_or_create(name="test") + + # Verify error was logged + mock_logger.error.assert_called_with("Failed in get_or_create: Test error") diff --git a/tests/ra_aid/database/test_utils.py b/tests/ra_aid/database/test_utils.py new file mode 100644 index 0000000..6c248fb --- /dev/null +++ b/tests/ra_aid/database/test_utils.py @@ -0,0 +1,220 @@ +""" +Tests for the database utils module. +""" + +from unittest.mock import patch, MagicMock + +import pytest +import peewee + +from ra_aid.database.connection import ( + db_var, get_db, init_db, close_db +) +from ra_aid.database.models import BaseModel +from ra_aid.database.utils import ensure_tables_created, get_model_count, truncate_table + + +@pytest.fixture +def cleanup_db(): + """Reset the database contextvar and connection state after each test.""" + # Reset before the test + db = db_var.get() + if db is not None: + try: + if not db.is_closed(): + db.close() + except Exception: + # Ignore errors when closing the database + pass + db_var.set(None) + + # Run the test + yield + + # Reset after the test + db = db_var.get() + if db is not None: + try: + if not db.is_closed(): + db.close() + except Exception: + # Ignore errors when closing the database + pass + db_var.set(None) + + +@pytest.fixture +def mock_logger(): + """Mock the logger to test for output messages.""" + with patch('ra_aid.database.utils.logger') as mock: + yield mock + + +@pytest.fixture +def setup_test_model(cleanup_db): + """Set up a test model for database tests.""" + # Initialize the database in memory + db = init_db(in_memory=True) + + # Define a test model class + class TestModel(BaseModel): + name = peewee.CharField(max_length=100) + value = peewee.IntegerField(default=0) + + class Meta: + database = db + + # Create the test table in a transaction + with db.atomic(): + db.create_tables([TestModel], safe=True) + + # Yield control to the test + yield TestModel + + # Clean up: drop the test table + with db.atomic(): + db.drop_tables([TestModel], safe=True) + + +def test_ensure_tables_created_with_models(cleanup_db, mock_logger): + """Test ensure_tables_created with explicit models.""" + # Initialize the database in memory + db = init_db(in_memory=True) + + # Define a test model that uses this database + class TestModel(BaseModel): + name = peewee.CharField(max_length=100) + value = peewee.IntegerField(default=0) + + class Meta: + database = db + + # Call ensure_tables_created with explicit models + ensure_tables_created([TestModel]) + + # Verify success message was logged + mock_logger.info.assert_called_with("Successfully created tables for 1 models") + + # Verify the table exists by trying to use it + TestModel.create(name="test", value=42) + count = TestModel.select().count() + assert count == 1 + + +def test_ensure_tables_created_no_models(cleanup_db, mock_logger): + """Test ensure_tables_created with no models.""" + # Initialize the database in memory + db = init_db(in_memory=True) + + # Mock the import to simulate no models found + with patch('ra_aid.database.utils.importlib.import_module', side_effect=ImportError("No module")): + # Call ensure_tables_created with no models + ensure_tables_created() + + # Verify warning message was logged + mock_logger.warning.assert_called_with("No models found to create tables for") + + +@patch('ra_aid.database.utils.get_db') +def test_ensure_tables_created_database_error(mock_get_db, setup_test_model, cleanup_db, mock_logger): + """Test ensure_tables_created handles database errors.""" + # Get the TestModel class from the fixture + TestModel = setup_test_model + + # Create a mock database with a create_tables method that raises an error + mock_db = MagicMock() + mock_db.atomic.return_value.__enter__.return_value = None + mock_db.atomic.return_value.__exit__.return_value = None + mock_db.create_tables.side_effect = peewee.DatabaseError("Test database error") + + # Configure get_db to return our mock + mock_get_db.return_value = mock_db + + # Call ensure_tables_created and expect an exception + with pytest.raises(peewee.DatabaseError): + ensure_tables_created([TestModel]) + + # Verify error message was logged + mock_logger.error.assert_called_with("Database Error: Failed to create tables: Test database error") + + +def test_get_model_count(setup_test_model, mock_logger): + """Test get_model_count returns the correct count.""" + # Get the TestModel class from the fixture + TestModel = setup_test_model + + # First ensure the table is empty + TestModel.delete().execute() + + # Create some test records + TestModel.create(name="test1", value=1) + TestModel.create(name="test2", value=2) + + # Call get_model_count + count = get_model_count(TestModel) + + # Verify the count is correct + assert count == 2 + + +@patch('peewee.ModelSelect.count') +def test_get_model_count_database_error(mock_count, setup_test_model, mock_logger): + """Test get_model_count handles database errors.""" + # Get the TestModel class from the fixture + TestModel = setup_test_model + + # Configure the mock to raise a DatabaseError + mock_count.side_effect = peewee.DatabaseError("Test count error") + + # Call get_model_count + count = get_model_count(TestModel) + + # Verify error message was logged + mock_logger.error.assert_called_with("Database Error: Failed to count records: Test count error") + + # Verify the function returns 0 on error + assert count == 0 + + +def test_truncate_table(setup_test_model, mock_logger): + """Test truncate_table deletes all records.""" + # Get the TestModel class from the fixture + TestModel = setup_test_model + + # Create some test records + TestModel.create(name="test1", value=1) + TestModel.create(name="test2", value=2) + + # Verify records exist + assert TestModel.select().count() == 2 + + # Call truncate_table + truncate_table(TestModel) + + # Verify success message was logged + mock_logger.info.assert_called_with(f"Successfully truncated table for {TestModel.__name__}") + + # Verify all records were deleted + assert TestModel.select().count() == 0 + + +@patch('ra_aid.database.models.BaseModel.delete') +def test_truncate_table_database_error(mock_delete, setup_test_model, mock_logger): + """Test truncate_table handles database errors.""" + # Get the TestModel class from the fixture + TestModel = setup_test_model + + # Create a test record + TestModel.create(name="test", value=42) + + # Configure the mock to return a mock query with execute that raises an error + mock_query = MagicMock() + mock_query.execute.side_effect = peewee.DatabaseError("Test delete error") + mock_delete.return_value = mock_query + + # Call truncate_table and expect an exception + with pytest.raises(peewee.DatabaseError): + truncate_table(TestModel) + + # Verify error message was logged + mock_logger.error.assert_called_with("Database Error: Failed to truncate table: Test delete error")