base db infra
This commit is contained in:
parent
f05d30ff50
commit
dbf4d954e1
|
|
@ -17,6 +17,7 @@ from ra_aid.agent_utils import (
|
|||
run_planning_agent,
|
||||
run_research_agent,
|
||||
)
|
||||
from ra_aid.database import init_db, close_db, DatabaseManager
|
||||
from ra_aid.config import (
|
||||
DEFAULT_MAX_TEST_CMD_RETRIES,
|
||||
DEFAULT_RECURSION_LIMIT,
|
||||
|
|
@ -333,201 +334,202 @@ def main():
|
|||
return
|
||||
|
||||
try:
|
||||
# Check dependencies before proceeding
|
||||
check_dependencies()
|
||||
with DatabaseManager() as db:
|
||||
# Check dependencies before proceeding
|
||||
check_dependencies()
|
||||
|
||||
expert_enabled, expert_missing, web_research_enabled, web_research_missing = (
|
||||
validate_environment(args)
|
||||
) # Will exit if main env vars missing
|
||||
logger.debug("Environment validation successful")
|
||||
expert_enabled, expert_missing, web_research_enabled, web_research_missing = (
|
||||
validate_environment(args)
|
||||
) # Will exit if main env vars missing
|
||||
logger.debug("Environment validation successful")
|
||||
|
||||
# Validate model configuration early
|
||||
# Validate model configuration early
|
||||
model_config = models_params.get(args.provider, {}).get(args.model or "", {})
|
||||
supports_temperature = model_config.get(
|
||||
"supports_temperature",
|
||||
args.provider
|
||||
in ["anthropic", "openai", "openrouter", "openai-compatible", "deepseek"],
|
||||
)
|
||||
|
||||
model_config = models_params.get(args.provider, {}).get(args.model or "", {})
|
||||
supports_temperature = model_config.get(
|
||||
"supports_temperature",
|
||||
args.provider
|
||||
in ["anthropic", "openai", "openrouter", "openai-compatible", "deepseek"],
|
||||
)
|
||||
|
||||
if supports_temperature and args.temperature is None:
|
||||
args.temperature = model_config.get("default_temperature")
|
||||
if args.temperature is None:
|
||||
cpm(
|
||||
f"This model supports temperature argument but none was given. Setting default temperature to {DEFAULT_TEMPERATURE}."
|
||||
if supports_temperature and args.temperature is None:
|
||||
args.temperature = model_config.get("default_temperature")
|
||||
if args.temperature is None:
|
||||
cpm(
|
||||
f"This model supports temperature argument but none was given. Setting default temperature to {DEFAULT_TEMPERATURE}."
|
||||
)
|
||||
args.temperature = DEFAULT_TEMPERATURE
|
||||
logger.debug(
|
||||
f"Using default temperature {args.temperature} for model {args.model}"
|
||||
)
|
||||
args.temperature = DEFAULT_TEMPERATURE
|
||||
logger.debug(
|
||||
f"Using default temperature {args.temperature} for model {args.model}"
|
||||
|
||||
status = build_status(args, expert_enabled, web_research_enabled)
|
||||
|
||||
console.print(
|
||||
Panel(status, title=f"RA.Aid v{__version__}", border_style="bright_blue", padding=(0, 1))
|
||||
)
|
||||
|
||||
status = build_status(args, expert_enabled, web_research_enabled)
|
||||
# Handle chat mode
|
||||
if args.chat:
|
||||
# Initialize chat model with default provider/model
|
||||
chat_model = initialize_llm(
|
||||
args.provider, args.model, temperature=args.temperature
|
||||
)
|
||||
|
||||
if args.research_only:
|
||||
print_error("Chat mode cannot be used with --research-only")
|
||||
sys.exit(1)
|
||||
|
||||
console.print(
|
||||
Panel(status, title=f"RA.Aid v{__version__}", border_style="bright_blue", padding=(0, 1))
|
||||
)
|
||||
print_stage_header("Chat Mode")
|
||||
|
||||
# Handle chat mode
|
||||
if args.chat:
|
||||
# Initialize chat model with default provider/model
|
||||
chat_model = initialize_llm(
|
||||
args.provider, args.model, temperature=args.temperature
|
||||
)
|
||||
if args.research_only:
|
||||
print_error("Chat mode cannot be used with --research-only")
|
||||
# Get project info
|
||||
try:
|
||||
project_info = get_project_info(".", file_limit=2000)
|
||||
formatted_project_info = format_project_info(project_info)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get project info: {e}")
|
||||
formatted_project_info = ""
|
||||
|
||||
# Get initial request from user
|
||||
initial_request = ask_human.invoke(
|
||||
{"question": "What would you like help with?"}
|
||||
)
|
||||
|
||||
# Get working directory and current date
|
||||
working_directory = os.getcwd()
|
||||
current_date = datetime.now().strftime("%Y-%m-%d")
|
||||
|
||||
# Run chat agent with CHAT_PROMPT
|
||||
config = {
|
||||
"configurable": {"thread_id": str(uuid.uuid4())},
|
||||
"recursion_limit": args.recursion_limit,
|
||||
"chat_mode": True,
|
||||
"cowboy_mode": args.cowboy_mode,
|
||||
"hil": True, # Always true in chat mode
|
||||
"web_research_enabled": web_research_enabled,
|
||||
"initial_request": initial_request,
|
||||
"limit_tokens": args.disable_limit_tokens,
|
||||
}
|
||||
|
||||
# Store config in global memory
|
||||
_global_memory["config"] = config
|
||||
_global_memory["config"]["provider"] = args.provider
|
||||
_global_memory["config"]["model"] = args.model
|
||||
_global_memory["config"]["expert_provider"] = args.expert_provider
|
||||
_global_memory["config"]["expert_model"] = args.expert_model
|
||||
_global_memory["config"]["temperature"] = args.temperature
|
||||
|
||||
# Create chat agent with appropriate tools
|
||||
chat_agent = create_agent(
|
||||
chat_model,
|
||||
get_chat_tools(
|
||||
expert_enabled=expert_enabled,
|
||||
web_research_enabled=web_research_enabled,
|
||||
),
|
||||
checkpointer=MemorySaver(),
|
||||
)
|
||||
|
||||
# Run chat agent and exit
|
||||
run_agent_with_retry(
|
||||
chat_agent,
|
||||
CHAT_PROMPT.format(
|
||||
initial_request=initial_request,
|
||||
web_research_section=(
|
||||
WEB_RESEARCH_PROMPT_SECTION_CHAT if web_research_enabled else ""
|
||||
),
|
||||
working_directory=working_directory,
|
||||
current_date=current_date,
|
||||
project_info=formatted_project_info,
|
||||
),
|
||||
config,
|
||||
)
|
||||
return
|
||||
|
||||
# Validate message is provided
|
||||
if not args.message:
|
||||
print_error("--message is required")
|
||||
sys.exit(1)
|
||||
|
||||
print_stage_header("Chat Mode")
|
||||
|
||||
# Get project info
|
||||
try:
|
||||
project_info = get_project_info(".", file_limit=2000)
|
||||
formatted_project_info = format_project_info(project_info)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get project info: {e}")
|
||||
formatted_project_info = ""
|
||||
|
||||
# Get initial request from user
|
||||
initial_request = ask_human.invoke(
|
||||
{"question": "What would you like help with?"}
|
||||
)
|
||||
|
||||
# Get working directory and current date
|
||||
working_directory = os.getcwd()
|
||||
current_date = datetime.now().strftime("%Y-%m-%d")
|
||||
|
||||
# Run chat agent with CHAT_PROMPT
|
||||
base_task = args.message
|
||||
config = {
|
||||
"configurable": {"thread_id": str(uuid.uuid4())},
|
||||
"recursion_limit": args.recursion_limit,
|
||||
"chat_mode": True,
|
||||
"research_only": args.research_only,
|
||||
"cowboy_mode": args.cowboy_mode,
|
||||
"hil": True, # Always true in chat mode
|
||||
"web_research_enabled": web_research_enabled,
|
||||
"initial_request": initial_request,
|
||||
"aider_config": args.aider_config,
|
||||
"limit_tokens": args.disable_limit_tokens,
|
||||
"auto_test": args.auto_test,
|
||||
"test_cmd": args.test_cmd,
|
||||
"max_test_cmd_retries": args.max_test_cmd_retries,
|
||||
"experimental_fallback_handler": args.experimental_fallback_handler,
|
||||
"test_cmd_timeout": args.test_cmd_timeout,
|
||||
}
|
||||
|
||||
# Store config in global memory
|
||||
# Store config in global memory for access by is_informational_query
|
||||
_global_memory["config"] = config
|
||||
|
||||
# Store base provider/model configuration
|
||||
_global_memory["config"]["provider"] = args.provider
|
||||
_global_memory["config"]["model"] = args.model
|
||||
|
||||
# Store expert provider/model (no fallback)
|
||||
_global_memory["config"]["expert_provider"] = args.expert_provider
|
||||
_global_memory["config"]["expert_model"] = args.expert_model
|
||||
|
||||
# Store planner config with fallback to base values
|
||||
_global_memory["config"]["planner_provider"] = (
|
||||
args.planner_provider or args.provider
|
||||
)
|
||||
_global_memory["config"]["planner_model"] = args.planner_model or args.model
|
||||
|
||||
# Store research config with fallback to base values
|
||||
_global_memory["config"]["research_provider"] = (
|
||||
args.research_provider or args.provider
|
||||
)
|
||||
_global_memory["config"]["research_model"] = args.research_model or args.model
|
||||
|
||||
# Store temperature in global config
|
||||
_global_memory["config"]["temperature"] = args.temperature
|
||||
|
||||
# Create chat agent with appropriate tools
|
||||
chat_agent = create_agent(
|
||||
chat_model,
|
||||
get_chat_tools(
|
||||
expert_enabled=expert_enabled,
|
||||
web_research_enabled=web_research_enabled,
|
||||
),
|
||||
checkpointer=MemorySaver(),
|
||||
# Run research stage
|
||||
print_stage_header("Research Stage")
|
||||
|
||||
# Initialize research model with potential overrides
|
||||
research_provider = args.research_provider or args.provider
|
||||
research_model_name = args.research_model or args.model
|
||||
research_model = initialize_llm(
|
||||
research_provider, research_model_name, temperature=args.temperature
|
||||
)
|
||||
|
||||
# Run chat agent and exit
|
||||
run_agent_with_retry(
|
||||
chat_agent,
|
||||
CHAT_PROMPT.format(
|
||||
initial_request=initial_request,
|
||||
web_research_section=(
|
||||
WEB_RESEARCH_PROMPT_SECTION_CHAT if web_research_enabled else ""
|
||||
),
|
||||
working_directory=working_directory,
|
||||
current_date=current_date,
|
||||
project_info=formatted_project_info,
|
||||
),
|
||||
config,
|
||||
)
|
||||
return
|
||||
|
||||
# Validate message is provided
|
||||
if not args.message:
|
||||
print_error("--message is required")
|
||||
sys.exit(1)
|
||||
|
||||
base_task = args.message
|
||||
config = {
|
||||
"configurable": {"thread_id": str(uuid.uuid4())},
|
||||
"recursion_limit": args.recursion_limit,
|
||||
"research_only": args.research_only,
|
||||
"cowboy_mode": args.cowboy_mode,
|
||||
"web_research_enabled": web_research_enabled,
|
||||
"aider_config": args.aider_config,
|
||||
"limit_tokens": args.disable_limit_tokens,
|
||||
"auto_test": args.auto_test,
|
||||
"test_cmd": args.test_cmd,
|
||||
"max_test_cmd_retries": args.max_test_cmd_retries,
|
||||
"experimental_fallback_handler": args.experimental_fallback_handler,
|
||||
"test_cmd_timeout": args.test_cmd_timeout,
|
||||
}
|
||||
|
||||
# Store config in global memory for access by is_informational_query
|
||||
_global_memory["config"] = config
|
||||
|
||||
# Store base provider/model configuration
|
||||
_global_memory["config"]["provider"] = args.provider
|
||||
_global_memory["config"]["model"] = args.model
|
||||
|
||||
# Store expert provider/model (no fallback)
|
||||
_global_memory["config"]["expert_provider"] = args.expert_provider
|
||||
_global_memory["config"]["expert_model"] = args.expert_model
|
||||
|
||||
# Store planner config with fallback to base values
|
||||
_global_memory["config"]["planner_provider"] = (
|
||||
args.planner_provider or args.provider
|
||||
)
|
||||
_global_memory["config"]["planner_model"] = args.planner_model or args.model
|
||||
|
||||
# Store research config with fallback to base values
|
||||
_global_memory["config"]["research_provider"] = (
|
||||
args.research_provider or args.provider
|
||||
)
|
||||
_global_memory["config"]["research_model"] = args.research_model or args.model
|
||||
|
||||
# Store temperature in global config
|
||||
_global_memory["config"]["temperature"] = args.temperature
|
||||
|
||||
# Run research stage
|
||||
print_stage_header("Research Stage")
|
||||
|
||||
# Initialize research model with potential overrides
|
||||
research_provider = args.research_provider or args.provider
|
||||
research_model_name = args.research_model or args.model
|
||||
research_model = initialize_llm(
|
||||
research_provider, research_model_name, temperature=args.temperature
|
||||
)
|
||||
|
||||
run_research_agent(
|
||||
base_task,
|
||||
research_model,
|
||||
expert_enabled=expert_enabled,
|
||||
research_only=args.research_only,
|
||||
hil=args.hil,
|
||||
memory=research_memory,
|
||||
config=config,
|
||||
)
|
||||
|
||||
# Proceed with planning and implementation if not an informational query
|
||||
if not is_informational_query():
|
||||
# Initialize planning model with potential overrides
|
||||
planner_provider = args.planner_provider or args.provider
|
||||
planner_model_name = args.planner_model or args.model
|
||||
planning_model = initialize_llm(
|
||||
planner_provider, planner_model_name, temperature=args.temperature
|
||||
)
|
||||
|
||||
# Run planning agent
|
||||
run_planning_agent(
|
||||
run_research_agent(
|
||||
base_task,
|
||||
planning_model,
|
||||
research_model,
|
||||
expert_enabled=expert_enabled,
|
||||
research_only=args.research_only,
|
||||
hil=args.hil,
|
||||
memory=planning_memory,
|
||||
memory=research_memory,
|
||||
config=config,
|
||||
)
|
||||
|
||||
# Proceed with planning and implementation if not an informational query
|
||||
if not is_informational_query():
|
||||
# Initialize planning model with potential overrides
|
||||
planner_provider = args.planner_provider or args.provider
|
||||
planner_model_name = args.planner_model or args.model
|
||||
planning_model = initialize_llm(
|
||||
planner_provider, planner_model_name, temperature=args.temperature
|
||||
)
|
||||
|
||||
# Run planning agent
|
||||
run_planning_agent(
|
||||
base_task,
|
||||
planning_model,
|
||||
expert_enabled=expert_enabled,
|
||||
hil=args.hil,
|
||||
memory=planning_memory,
|
||||
config=config,
|
||||
)
|
||||
|
||||
except (KeyboardInterrupt, AgentInterrupt):
|
||||
print()
|
||||
print(" 👋 Bye!")
|
||||
|
|
|
|||
|
|
@ -0,0 +1,26 @@
|
|||
"""
|
||||
Database package for ra_aid.
|
||||
|
||||
This package provides database functionality for the ra_aid application,
|
||||
including connection management, models, and utility functions.
|
||||
"""
|
||||
|
||||
from ra_aid.database.connection import (
|
||||
init_db,
|
||||
get_db,
|
||||
close_db,
|
||||
DatabaseManager
|
||||
)
|
||||
from ra_aid.database.models import BaseModel
|
||||
from ra_aid.database.utils import get_model_count, truncate_table, ensure_tables_created
|
||||
|
||||
__all__ = [
|
||||
'init_db',
|
||||
'get_db',
|
||||
'close_db',
|
||||
'DatabaseManager',
|
||||
'BaseModel',
|
||||
'get_model_count',
|
||||
'truncate_table',
|
||||
'ensure_tables_created',
|
||||
]
|
||||
|
|
@ -0,0 +1,371 @@
|
|||
"""
|
||||
Database connection management for ra_aid.
|
||||
|
||||
This module provides functions to initialize, get, and close database connections.
|
||||
It also provides a context manager for database connections.
|
||||
"""
|
||||
|
||||
import os
|
||||
import contextvars
|
||||
from pathlib import Path
|
||||
from typing import Optional, Any
|
||||
|
||||
import peewee
|
||||
|
||||
from ra_aid.logging_config import get_logger
|
||||
|
||||
# Create contextvar to hold the database connection
|
||||
db_var = contextvars.ContextVar("db", default=None)
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class DatabaseManager:
|
||||
"""
|
||||
Context manager for database connections.
|
||||
|
||||
This class provides a context manager interface for database connections,
|
||||
using the existing contextvars approach internally.
|
||||
|
||||
Example:
|
||||
with DatabaseManager() as db:
|
||||
# Use the database connection
|
||||
db.execute_sql("SELECT * FROM table")
|
||||
|
||||
# Or with in-memory database:
|
||||
with DatabaseManager(in_memory=True) as db:
|
||||
# Use in-memory database
|
||||
"""
|
||||
|
||||
def __init__(self, in_memory: bool = False):
|
||||
"""
|
||||
Initialize the DatabaseManager.
|
||||
|
||||
Args:
|
||||
in_memory: Whether to use an in-memory database (default: False)
|
||||
"""
|
||||
self.in_memory = in_memory
|
||||
|
||||
def __enter__(self) -> peewee.SqliteDatabase:
|
||||
"""
|
||||
Initialize the database connection and return it.
|
||||
|
||||
Returns:
|
||||
peewee.SqliteDatabase: The initialized database connection
|
||||
"""
|
||||
return init_db(in_memory=self.in_memory)
|
||||
|
||||
def __exit__(self, exc_type: Optional[type], exc_val: Optional[Exception],
|
||||
exc_tb: Optional[Any]) -> None:
|
||||
"""
|
||||
Close the database connection when exiting the context.
|
||||
|
||||
Args:
|
||||
exc_type: The exception type if an exception was raised
|
||||
exc_val: The exception value if an exception was raised
|
||||
exc_tb: The traceback if an exception was raised
|
||||
"""
|
||||
close_db()
|
||||
|
||||
# Don't suppress exceptions
|
||||
return False
|
||||
|
||||
def init_db(in_memory: bool = False) -> peewee.SqliteDatabase:
|
||||
"""
|
||||
Initialize the database connection.
|
||||
|
||||
Creates the .ra-aid directory if it doesn't exist and initializes
|
||||
the SQLite database connection. If a database connection already exists,
|
||||
returns the existing connection instead of creating a new one.
|
||||
|
||||
Args:
|
||||
in_memory: Whether to use an in-memory database (default: False)
|
||||
|
||||
Returns:
|
||||
peewee.SqliteDatabase: The initialized database connection
|
||||
"""
|
||||
# Check if a database connection already exists
|
||||
existing_db = db_var.get()
|
||||
if existing_db is not None:
|
||||
# If the connection exists but is closed, reopen it
|
||||
if existing_db.is_closed():
|
||||
try:
|
||||
existing_db.connect()
|
||||
except peewee.OperationalError as e:
|
||||
logger.error(f"Failed to reopen existing database connection: {str(e)}")
|
||||
# Continue to create a new connection if reopening fails
|
||||
else:
|
||||
return existing_db
|
||||
else:
|
||||
# Connection exists and is open, return it
|
||||
return existing_db
|
||||
|
||||
# Set up database path
|
||||
if in_memory:
|
||||
# Use in-memory database
|
||||
db_path = ":memory:"
|
||||
logger.debug("Using in-memory SQLite database")
|
||||
else:
|
||||
# Get current working directory and create .ra-aid directory if it doesn't exist
|
||||
cwd = os.getcwd()
|
||||
logger.debug(f"Current working directory: {cwd}")
|
||||
print(f"DIRECT DEBUG: Current working directory: {cwd}")
|
||||
|
||||
# Define the .ra-aid directory path
|
||||
ra_aid_dir_str = os.path.join(cwd, ".ra-aid")
|
||||
ra_aid_dir = Path(ra_aid_dir_str)
|
||||
ra_aid_dir = ra_aid_dir.absolute() # Ensure we have the absolute path
|
||||
ra_aid_dir_str = str(ra_aid_dir) # Update string representation with absolute path
|
||||
|
||||
logger.debug(f"Creating database directory at: {ra_aid_dir_str}")
|
||||
print(f"DIRECT DEBUG: Creating database directory at: {ra_aid_dir_str}")
|
||||
|
||||
# Multiple approaches to ensure directory creation
|
||||
directory_created = False
|
||||
error_messages = []
|
||||
|
||||
# Approach 1: Try os.mkdir directly
|
||||
if not os.path.exists(ra_aid_dir_str):
|
||||
try:
|
||||
logger.debug("Attempting directory creation with os.mkdir")
|
||||
print(f"DIRECT DEBUG: Attempting directory creation with os.mkdir: {ra_aid_dir_str}")
|
||||
os.mkdir(ra_aid_dir_str, mode=0o755)
|
||||
directory_created = os.path.exists(ra_aid_dir_str) and os.path.isdir(ra_aid_dir_str)
|
||||
if directory_created:
|
||||
logger.debug("Directory created successfully with os.mkdir")
|
||||
print(f"DIRECT DEBUG: Directory created successfully with os.mkdir")
|
||||
except Exception as e:
|
||||
error_msg = f"os.mkdir failed: {str(e)}"
|
||||
logger.debug(error_msg)
|
||||
print(f"DIRECT DEBUG: {error_msg}")
|
||||
error_messages.append(error_msg)
|
||||
else:
|
||||
logger.debug("Directory already exists, skipping creation")
|
||||
print(f"DIRECT DEBUG: Directory already exists at {ra_aid_dir_str}")
|
||||
directory_created = True
|
||||
|
||||
# Approach 2: Try os.makedirs if os.mkdir failed
|
||||
if not directory_created:
|
||||
try:
|
||||
logger.debug("Attempting directory creation with os.makedirs")
|
||||
print(f"DIRECT DEBUG: Attempting directory creation with os.makedirs: {ra_aid_dir_str}")
|
||||
os.makedirs(ra_aid_dir_str, exist_ok=True, mode=0o755)
|
||||
directory_created = os.path.exists(ra_aid_dir_str) and os.path.isdir(ra_aid_dir_str)
|
||||
if directory_created:
|
||||
logger.debug("Directory created successfully with os.makedirs")
|
||||
print(f"DIRECT DEBUG: Directory created successfully with os.makedirs")
|
||||
except Exception as e:
|
||||
error_msg = f"os.makedirs failed: {str(e)}"
|
||||
logger.debug(error_msg)
|
||||
print(f"DIRECT DEBUG: {error_msg}")
|
||||
error_messages.append(error_msg)
|
||||
|
||||
# Approach 3: Try Path.mkdir if previous methods failed
|
||||
if not directory_created:
|
||||
try:
|
||||
logger.debug("Attempting directory creation with Path.mkdir")
|
||||
print(f"DIRECT DEBUG: Attempting directory creation with Path.mkdir: {ra_aid_dir}")
|
||||
ra_aid_dir.mkdir(mode=0o755, parents=True, exist_ok=True)
|
||||
directory_created = os.path.exists(ra_aid_dir_str) and os.path.isdir(ra_aid_dir_str)
|
||||
if directory_created:
|
||||
logger.debug("Directory created successfully with Path.mkdir")
|
||||
print(f"DIRECT DEBUG: Directory created successfully with Path.mkdir")
|
||||
except Exception as e:
|
||||
error_msg = f"Path.mkdir failed: {str(e)}"
|
||||
logger.debug(error_msg)
|
||||
print(f"DIRECT DEBUG: {error_msg}")
|
||||
error_messages.append(error_msg)
|
||||
|
||||
# Verify the directory was actually created
|
||||
path_exists = ra_aid_dir.exists()
|
||||
os_exists = os.path.exists(ra_aid_dir_str)
|
||||
is_dir = os.path.isdir(ra_aid_dir_str) if os_exists else False
|
||||
|
||||
logger.debug(f"Directory verification: Path.exists={path_exists}, os.path.exists={os_exists}, os.path.isdir={is_dir}")
|
||||
print(f"DIRECT DEBUG: Directory verification: Path.exists={path_exists}, os.path.exists={os_exists}, os.path.isdir={is_dir}")
|
||||
|
||||
# Check parent directory permissions and contents for debugging
|
||||
try:
|
||||
parent_dir = os.path.dirname(ra_aid_dir_str)
|
||||
parent_perms = oct(os.stat(parent_dir).st_mode)[-3:]
|
||||
parent_contents = os.listdir(parent_dir)
|
||||
logger.debug(f"Parent directory {parent_dir} permissions: {parent_perms}")
|
||||
logger.debug(f"Parent directory contents: {parent_contents}")
|
||||
print(f"DIRECT DEBUG: Parent directory {parent_dir} permissions: {parent_perms}")
|
||||
print(f"DIRECT DEBUG: Parent directory contents: {parent_contents}")
|
||||
except Exception as e:
|
||||
logger.debug(f"Could not check parent directory: {str(e)}")
|
||||
print(f"DIRECT DEBUG: Could not check parent directory: {str(e)}")
|
||||
|
||||
if not os_exists or not is_dir:
|
||||
error_msg = f"Directory does not exist or is not a directory after creation attempts: {ra_aid_dir_str}"
|
||||
logger.error(error_msg)
|
||||
print(f"DIRECT DEBUG ERROR: {error_msg}")
|
||||
if error_messages:
|
||||
logger.error(f"Previous errors: {', '.join(error_messages)}")
|
||||
print(f"DIRECT DEBUG ERROR: Previous errors: {', '.join(error_messages)}")
|
||||
raise FileNotFoundError(f"Failed to create directory: {ra_aid_dir_str}")
|
||||
|
||||
# Check directory permissions
|
||||
try:
|
||||
permissions = oct(os.stat(ra_aid_dir_str).st_mode)[-3:]
|
||||
logger.debug(f"Directory created/verified: {ra_aid_dir_str} with permissions {permissions}")
|
||||
print(f"DIRECT DEBUG: Directory created/verified: {ra_aid_dir_str} with permissions {permissions}")
|
||||
|
||||
# List directory contents for debugging
|
||||
dir_contents = os.listdir(ra_aid_dir_str)
|
||||
logger.debug(f"Directory contents: {dir_contents}")
|
||||
print(f"DIRECT DEBUG: Directory contents: {dir_contents}")
|
||||
except Exception as e:
|
||||
logger.debug(f"Could not check directory details: {str(e)}")
|
||||
print(f"DIRECT DEBUG: Could not check directory details: {str(e)}")
|
||||
|
||||
# Database path for file-based database - use os.path.join for maximum compatibility
|
||||
db_path = os.path.join(ra_aid_dir_str, "pk.db")
|
||||
logger.debug(f"Database path: {db_path}")
|
||||
print(f"DIRECT DEBUG: Database path: {db_path}")
|
||||
|
||||
try:
|
||||
# For file-based databases, ensure the file exists or can be created
|
||||
if db_path != ":memory:":
|
||||
# Check if the database file exists
|
||||
db_file_exists = os.path.exists(db_path)
|
||||
logger.debug(f"Database file exists check: {db_file_exists}")
|
||||
print(f"DIRECT DEBUG: Database file exists check: {db_file_exists}")
|
||||
|
||||
# If the file doesn't exist, try to create an empty file to ensure we have write permissions
|
||||
if not db_file_exists:
|
||||
try:
|
||||
logger.debug(f"Creating empty database file at: {db_path}")
|
||||
print(f"DIRECT DEBUG: Creating empty database file at: {db_path}")
|
||||
with open(db_path, 'w') as f:
|
||||
pass # Create empty file
|
||||
|
||||
# Verify the file was created
|
||||
if os.path.exists(db_path):
|
||||
logger.debug("Empty database file created successfully")
|
||||
print(f"DIRECT DEBUG: Empty database file created successfully")
|
||||
else:
|
||||
logger.error(f"Failed to create database file at: {db_path}")
|
||||
print(f"DIRECT DEBUG ERROR: Failed to create database file at: {db_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating database file: {str(e)}")
|
||||
print(f"DIRECT DEBUG ERROR: Error creating database file: {str(e)}")
|
||||
# Continue anyway, as SQLite might be able to create the file itself
|
||||
|
||||
# Initialize the database connection
|
||||
logger.debug(f"Initializing SQLite database at: {db_path}")
|
||||
print(f"DIRECT DEBUG: Initializing SQLite database at: {db_path}")
|
||||
db = peewee.SqliteDatabase(
|
||||
db_path,
|
||||
pragmas={
|
||||
'journal_mode': 'wal', # Write-Ahead Logging for better concurrency
|
||||
'foreign_keys': 1, # Enforce foreign key constraints
|
||||
'cache_size': -1024 * 32, # 32MB cache
|
||||
}
|
||||
)
|
||||
|
||||
# Always explicitly connect to ensure the connection is established
|
||||
if db.is_closed():
|
||||
logger.debug("Explicitly connecting to database")
|
||||
print(f"DIRECT DEBUG: Explicitly connecting to database")
|
||||
db.connect()
|
||||
|
||||
# Store the database connection in the contextvar
|
||||
db_var.set(db)
|
||||
|
||||
# Store whether this is an in-memory database (for backward compatibility)
|
||||
db._is_in_memory = in_memory
|
||||
|
||||
# Verify the database is usable by executing a simple query
|
||||
if not in_memory:
|
||||
try:
|
||||
db.execute_sql("SELECT 1")
|
||||
logger.debug("Database connection verified with test query")
|
||||
print(f"DIRECT DEBUG: Database connection verified with test query")
|
||||
|
||||
# Check if the database file exists after initialization
|
||||
db_file_exists = os.path.exists(db_path)
|
||||
db_file_size = os.path.getsize(db_path) if db_file_exists else 0
|
||||
logger.debug(f"Database file check after init: exists={db_file_exists}, size={db_file_size} bytes")
|
||||
print(f"DIRECT DEBUG: Database file check after init: exists={db_file_exists}, size={db_file_size} bytes")
|
||||
except Exception as e:
|
||||
logger.error(f"Database verification failed: {str(e)}")
|
||||
print(f"DIRECT DEBUG ERROR: Database verification failed: {str(e)}")
|
||||
# Continue anyway, as this is just a verification step
|
||||
|
||||
# Only show initialization message if it hasn't been shown before
|
||||
if not hasattr(db, '_message_shown') or not db._message_shown:
|
||||
if in_memory:
|
||||
logger.debug("In-memory database connection initialized successfully")
|
||||
else:
|
||||
logger.debug("Database connection initialized successfully")
|
||||
db._message_shown = True
|
||||
|
||||
return db
|
||||
except peewee.OperationalError as e:
|
||||
logger.error(f"Database Operational Error: {str(e)}")
|
||||
raise
|
||||
except peewee.DatabaseError as e:
|
||||
logger.error(f"Database Error: {str(e)}")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize database: {str(e)}")
|
||||
raise
|
||||
|
||||
def get_db() -> peewee.SqliteDatabase:
|
||||
"""
|
||||
Get the current database connection.
|
||||
|
||||
If no connection exists, initializes a new one.
|
||||
If connection exists but is closed, reopens it.
|
||||
|
||||
Returns:
|
||||
peewee.SqliteDatabase: The current database connection
|
||||
"""
|
||||
db = db_var.get()
|
||||
|
||||
if db is None:
|
||||
# No database connection exists, initialize one
|
||||
# Use the default in-memory mode (False)
|
||||
return init_db(in_memory=False)
|
||||
|
||||
# Check if connection is closed and reopen if needed
|
||||
if db.is_closed():
|
||||
try:
|
||||
logger.debug("Attempting to reopen closed database connection")
|
||||
db.connect()
|
||||
logger.info("Reopened existing database connection")
|
||||
except peewee.OperationalError as e:
|
||||
logger.error(f"Failed to reopen database connection: {str(e)}")
|
||||
# Create a completely new connection
|
||||
# First, remove the old connection from the context var
|
||||
db_var.set(None)
|
||||
# Then initialize a new connection with the same in-memory setting
|
||||
in_memory = hasattr(db, '_is_in_memory') and db._is_in_memory
|
||||
logger.debug(f"Creating new database connection (in_memory={in_memory})")
|
||||
# Create a completely new database object, don't reuse the old one
|
||||
return init_db(in_memory=in_memory)
|
||||
|
||||
return db
|
||||
|
||||
def close_db() -> None:
|
||||
"""
|
||||
Close the current database connection if it exists.
|
||||
|
||||
Handles various error conditions gracefully.
|
||||
"""
|
||||
db = db_var.get()
|
||||
if db is None:
|
||||
logger.warning("No database connection to close")
|
||||
return
|
||||
|
||||
try:
|
||||
if not db.is_closed():
|
||||
db.close()
|
||||
logger.info("Database connection closed successfully")
|
||||
else:
|
||||
logger.warning("Database connection was already closed")
|
||||
except peewee.DatabaseError as e:
|
||||
logger.error(f"Database Error: Failed to close connection: {str(e)}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to close database connection: {str(e)}")
|
||||
|
|
@ -0,0 +1,62 @@
|
|||
"""
|
||||
Database models for ra_aid.
|
||||
|
||||
This module defines the base model class that all models will inherit from.
|
||||
"""
|
||||
|
||||
import datetime
|
||||
from typing import Any, Dict, Type, TypeVar
|
||||
|
||||
import peewee
|
||||
|
||||
from ra_aid.database.connection import get_db
|
||||
from ra_aid.logging_config import get_logger
|
||||
|
||||
T = TypeVar('T', bound='BaseModel')
|
||||
logger = get_logger(__name__)
|
||||
|
||||
class BaseModel(peewee.Model):
|
||||
"""
|
||||
Base model class for all ra_aid models.
|
||||
|
||||
All models should inherit from this class to ensure consistent
|
||||
behavior and database connection.
|
||||
"""
|
||||
created_at = peewee.DateTimeField(default=datetime.datetime.now)
|
||||
updated_at = peewee.DateTimeField(default=datetime.datetime.now)
|
||||
|
||||
class Meta:
|
||||
database = get_db()
|
||||
|
||||
def save(self, *args: Any, **kwargs: Any) -> int:
|
||||
"""
|
||||
Override save to update the updated_at field.
|
||||
|
||||
Args:
|
||||
*args: Arguments to pass to the parent save method
|
||||
**kwargs: Keyword arguments to pass to the parent save method
|
||||
|
||||
Returns:
|
||||
int: The primary key of the saved instance
|
||||
"""
|
||||
self.updated_at = datetime.datetime.now()
|
||||
return super().save(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def get_or_create(cls: Type[T], **kwargs: Any) -> tuple[T, bool]:
|
||||
"""
|
||||
Get an instance or create it if it doesn't exist.
|
||||
|
||||
Args:
|
||||
**kwargs: Fields to use for lookup and creation
|
||||
|
||||
Returns:
|
||||
tuple: (instance, created) where created is a boolean indicating
|
||||
whether a new instance was created
|
||||
"""
|
||||
try:
|
||||
return super().get_or_create(**kwargs)
|
||||
except peewee.DatabaseError as e:
|
||||
# Log the error with logger
|
||||
logger.error(f"Failed in get_or_create: {str(e)}")
|
||||
raise
|
||||
|
|
@ -0,0 +1,653 @@
|
|||
"""
|
||||
Tests for the database connection module.
|
||||
"""
|
||||
|
||||
import os
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
import peewee
|
||||
|
||||
from ra_aid.database.connection import (
|
||||
init_db, get_db, close_db, db_var, DatabaseManager
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def cleanup_db():
|
||||
"""
|
||||
Fixture to clean up database connections after tests.
|
||||
"""
|
||||
# Run the test
|
||||
yield
|
||||
|
||||
# Clean up after the test
|
||||
db = db_var.get()
|
||||
if db is not None:
|
||||
if hasattr(db, '_is_in_memory'):
|
||||
delattr(db, '_is_in_memory')
|
||||
if hasattr(db, '_message_shown'):
|
||||
delattr(db, '_message_shown')
|
||||
if not db.is_closed():
|
||||
db.close()
|
||||
|
||||
# Reset the contextvar
|
||||
db_var.set(None)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def setup_in_memory_db():
|
||||
"""
|
||||
Fixture to set up an in-memory database for testing.
|
||||
"""
|
||||
# Initialize in-memory database
|
||||
db = init_db(in_memory=True)
|
||||
|
||||
# Run the test
|
||||
yield db
|
||||
|
||||
# Clean up
|
||||
if not db.is_closed():
|
||||
db.close()
|
||||
db_var.set(None)
|
||||
|
||||
|
||||
def test_init_db_creates_directory(cleanup_db, tmp_path):
|
||||
"""
|
||||
Test that init_db creates the .ra-aid directory if it doesn't exist.
|
||||
"""
|
||||
# Get and print the original working directory
|
||||
original_cwd = os.getcwd()
|
||||
print(f"Original working directory: {original_cwd}")
|
||||
|
||||
# Convert tmp_path to string for consistent handling
|
||||
tmp_path_str = str(tmp_path.absolute())
|
||||
print(f"Temporary directory path: {tmp_path_str}")
|
||||
|
||||
# Change to the temporary directory
|
||||
os.chdir(tmp_path_str)
|
||||
current_cwd = os.getcwd()
|
||||
print(f"Changed working directory to: {current_cwd}")
|
||||
assert current_cwd == tmp_path_str, f"Failed to change directory: {current_cwd} != {tmp_path_str}"
|
||||
|
||||
# Create the .ra-aid directory manually to ensure it exists
|
||||
ra_aid_path_str = os.path.join(current_cwd, ".ra-aid")
|
||||
print(f"Creating .ra-aid directory at: {ra_aid_path_str}")
|
||||
os.makedirs(ra_aid_path_str, exist_ok=True)
|
||||
|
||||
# Verify the directory was created
|
||||
assert os.path.exists(ra_aid_path_str), f".ra-aid directory not found at {ra_aid_path_str}"
|
||||
assert os.path.isdir(ra_aid_path_str), f"{ra_aid_path_str} exists but is not a directory"
|
||||
|
||||
# Create a test file to verify write permissions
|
||||
test_file_path = os.path.join(ra_aid_path_str, "test_write.txt")
|
||||
print(f"Creating test file to verify write permissions: {test_file_path}")
|
||||
with open(test_file_path, 'w') as f:
|
||||
f.write("Test write permissions")
|
||||
|
||||
# Verify the test file was created
|
||||
assert os.path.exists(test_file_path), f"Test file not created at {test_file_path}"
|
||||
|
||||
# Create an empty database file to ensure it exists before init_db
|
||||
db_file_str = os.path.join(ra_aid_path_str, "pk.db")
|
||||
print(f"Creating empty database file at: {db_file_str}")
|
||||
with open(db_file_str, 'w') as f:
|
||||
f.write("") # Create empty file
|
||||
|
||||
# Verify the database file was created
|
||||
assert os.path.exists(db_file_str), f"Empty database file not created at {db_file_str}"
|
||||
print(f"Empty database file size: {os.path.getsize(db_file_str)} bytes")
|
||||
|
||||
# Get directory permissions for debugging
|
||||
dir_perms = oct(os.stat(ra_aid_path_str).st_mode)[-3:]
|
||||
print(f"Directory permissions: {dir_perms}")
|
||||
|
||||
# Initialize the database
|
||||
print("Calling init_db()")
|
||||
db = init_db()
|
||||
print("init_db() returned successfully")
|
||||
|
||||
# List contents of the current directory for debugging
|
||||
print(f"Contents of current directory: {os.listdir(current_cwd)}")
|
||||
|
||||
# List contents of the .ra-aid directory for debugging
|
||||
print(f"Contents of .ra-aid directory: {os.listdir(ra_aid_path_str)}")
|
||||
|
||||
# Check that the database file exists using os.path
|
||||
assert os.path.exists(db_file_str), f"Database file not found at {db_file_str}"
|
||||
assert os.path.isfile(db_file_str), f"{db_file_str} exists but is not a file"
|
||||
print(f"Final database file size: {os.path.getsize(db_file_str)} bytes")
|
||||
|
||||
|
||||
def test_init_db_creates_database_file(cleanup_db, tmp_path):
|
||||
"""
|
||||
Test that init_db creates the database file.
|
||||
"""
|
||||
# Change to the temporary directory
|
||||
os.chdir(tmp_path)
|
||||
|
||||
# Initialize the database
|
||||
init_db()
|
||||
|
||||
# Check that the database file was created
|
||||
assert (tmp_path / ".ra-aid" / "pk.db").exists()
|
||||
assert (tmp_path / ".ra-aid" / "pk.db").is_file()
|
||||
|
||||
|
||||
def test_init_db_returns_database_connection(cleanup_db):
|
||||
"""
|
||||
Test that init_db returns a database connection.
|
||||
"""
|
||||
# Initialize the database
|
||||
db = init_db()
|
||||
|
||||
# Check that the database connection is returned
|
||||
assert isinstance(db, peewee.SqliteDatabase)
|
||||
assert not db.is_closed()
|
||||
|
||||
|
||||
def test_init_db_with_in_memory_mode(cleanup_db):
|
||||
"""
|
||||
Test that init_db with in_memory=True creates an in-memory database.
|
||||
"""
|
||||
# Initialize the database in in-memory mode
|
||||
db = init_db(in_memory=True)
|
||||
|
||||
# Check that the database connection is returned
|
||||
assert isinstance(db, peewee.SqliteDatabase)
|
||||
assert not db.is_closed()
|
||||
assert hasattr(db, '_is_in_memory')
|
||||
assert db._is_in_memory is True
|
||||
|
||||
|
||||
def test_in_memory_mode_no_directory_created(cleanup_db, tmp_path):
|
||||
"""
|
||||
Test that when using in-memory mode, no directory is created.
|
||||
"""
|
||||
# Change to the temporary directory
|
||||
os.chdir(tmp_path)
|
||||
|
||||
# Initialize the database in in-memory mode
|
||||
init_db(in_memory=True)
|
||||
|
||||
# Check that the .ra-aid directory was not created
|
||||
# (Note: it might be created by other tests, so we can't assert it doesn't exist)
|
||||
# Instead, check that the database file was not created
|
||||
assert not (tmp_path / ".ra-aid" / "pk.db").exists()
|
||||
|
||||
|
||||
def test_init_db_returns_existing_connection(cleanup_db):
|
||||
"""
|
||||
Test that init_db returns the existing connection if one exists.
|
||||
"""
|
||||
# Initialize the database
|
||||
db1 = init_db()
|
||||
|
||||
# Initialize the database again
|
||||
db2 = init_db()
|
||||
|
||||
# Check that the same connection is returned
|
||||
assert db1 is db2
|
||||
|
||||
|
||||
def test_init_db_reopens_closed_connection(cleanup_db):
|
||||
"""
|
||||
Test that init_db reopens a closed connection.
|
||||
"""
|
||||
# Initialize the database
|
||||
db1 = init_db()
|
||||
|
||||
# Close the connection
|
||||
db1.close()
|
||||
|
||||
# Initialize the database again
|
||||
db2 = init_db()
|
||||
|
||||
# Check that the same connection is returned and it's open
|
||||
assert db1 is db2
|
||||
assert not db1.is_closed()
|
||||
|
||||
|
||||
def test_get_db_initializes_connection(cleanup_db):
|
||||
"""
|
||||
Test that get_db initializes a connection if none exists.
|
||||
"""
|
||||
# Get the database connection
|
||||
db = get_db()
|
||||
|
||||
# Check that a connection was initialized
|
||||
assert isinstance(db, peewee.SqliteDatabase)
|
||||
assert not db.is_closed()
|
||||
|
||||
|
||||
def test_get_db_returns_existing_connection(cleanup_db):
|
||||
"""
|
||||
Test that get_db returns the existing connection if one exists.
|
||||
"""
|
||||
# Initialize the database
|
||||
db1 = init_db()
|
||||
|
||||
# Get the database connection
|
||||
db2 = get_db()
|
||||
|
||||
# Check that the same connection is returned
|
||||
assert db1 is db2
|
||||
|
||||
|
||||
def test_get_db_reopens_closed_connection(cleanup_db):
|
||||
"""
|
||||
Test that get_db reopens a closed connection.
|
||||
"""
|
||||
# Initialize the database
|
||||
db = init_db()
|
||||
|
||||
# Close the connection
|
||||
db.close()
|
||||
|
||||
# Get the database connection
|
||||
db2 = get_db()
|
||||
|
||||
# Check that the same connection is returned and it's open
|
||||
assert db is db2
|
||||
assert not db.is_closed()
|
||||
|
||||
|
||||
def test_get_db_handles_reopen_error(cleanup_db, monkeypatch):
|
||||
"""
|
||||
Test that get_db handles errors when reopening a connection.
|
||||
"""
|
||||
# Initialize the database
|
||||
db = init_db()
|
||||
|
||||
# Close the connection
|
||||
db.close()
|
||||
|
||||
# Create a patched version of the connect method that raises an error
|
||||
original_connect = peewee.SqliteDatabase.connect
|
||||
|
||||
def mock_connect(self, *args, **kwargs):
|
||||
if self is db: # Only raise for the specific db instance
|
||||
raise peewee.OperationalError("Test error")
|
||||
return original_connect(self, *args, **kwargs)
|
||||
|
||||
# Apply the patch
|
||||
monkeypatch.setattr(peewee.SqliteDatabase, 'connect', mock_connect)
|
||||
|
||||
# Get the database connection
|
||||
db2 = get_db()
|
||||
|
||||
# Check that a new connection was initialized
|
||||
assert db is not db2
|
||||
assert not db2.is_closed()
|
||||
|
||||
|
||||
def test_close_db_closes_connection(cleanup_db):
|
||||
"""
|
||||
Test that close_db closes the connection.
|
||||
"""
|
||||
# Initialize the database
|
||||
db = init_db()
|
||||
|
||||
# Close the connection
|
||||
close_db()
|
||||
|
||||
# Check that the connection is closed
|
||||
assert db.is_closed()
|
||||
|
||||
|
||||
def test_close_db_handles_no_connection():
|
||||
"""
|
||||
Test that close_db handles the case where no connection exists.
|
||||
"""
|
||||
# Reset the contextvar
|
||||
db_var.set(None)
|
||||
|
||||
# Close the connection (should not raise an error)
|
||||
close_db()
|
||||
|
||||
|
||||
def test_close_db_handles_already_closed_connection(cleanup_db):
|
||||
"""
|
||||
Test that close_db handles the case where the connection is already closed.
|
||||
"""
|
||||
# Initialize the database
|
||||
db = init_db()
|
||||
|
||||
# Close the connection
|
||||
db.close()
|
||||
|
||||
# Close the connection again (should not raise an error)
|
||||
close_db()
|
||||
|
||||
|
||||
@patch('ra_aid.database.connection.peewee.SqliteDatabase.close')
|
||||
def test_close_db_handles_error(mock_close, cleanup_db):
|
||||
"""
|
||||
Test that close_db handles errors when closing the connection.
|
||||
"""
|
||||
# Initialize the database
|
||||
init_db()
|
||||
|
||||
# Make close raise an error
|
||||
mock_close.side_effect = peewee.DatabaseError("Test error")
|
||||
|
||||
# Close the connection (should not raise an error)
|
||||
close_db()
|
||||
|
||||
|
||||
def test_database_manager_context_manager(cleanup_db):
|
||||
"""
|
||||
Test that DatabaseManager works as a context manager.
|
||||
"""
|
||||
# Use the context manager
|
||||
with DatabaseManager() as db:
|
||||
# Check that a connection was initialized
|
||||
assert isinstance(db, peewee.SqliteDatabase)
|
||||
assert not db.is_closed()
|
||||
|
||||
# Store the connection for later
|
||||
db_in_context = db
|
||||
|
||||
# Check that the connection is closed after exiting the context
|
||||
assert db_in_context.is_closed()
|
||||
|
||||
|
||||
def test_database_manager_with_in_memory_mode(cleanup_db):
|
||||
"""
|
||||
Test that DatabaseManager with in_memory=True creates an in-memory database.
|
||||
"""
|
||||
# Use the context manager with in_memory=True
|
||||
with DatabaseManager(in_memory=True) as db:
|
||||
# Check that a connection was initialized
|
||||
assert isinstance(db, peewee.SqliteDatabase)
|
||||
assert not db.is_closed()
|
||||
assert hasattr(db, '_is_in_memory')
|
||||
assert db._is_in_memory is True
|
||||
|
||||
|
||||
def test_init_db_shows_message_only_once(cleanup_db, caplog):
|
||||
"""
|
||||
Test that init_db only shows the initialization message once.
|
||||
"""
|
||||
# Initialize the database
|
||||
init_db(in_memory=True)
|
||||
|
||||
# Clear the log
|
||||
caplog.clear()
|
||||
|
||||
# Initialize the database again
|
||||
init_db(in_memory=True)
|
||||
|
||||
# Check that no message was logged
|
||||
assert "database connection initialized" not in caplog.text.lower()
|
||||
|
||||
|
||||
def test_init_db_sets_is_in_memory_attribute(cleanup_db):
|
||||
"""
|
||||
Test that init_db sets the _is_in_memory attribute.
|
||||
"""
|
||||
# Initialize the database with in_memory=False
|
||||
db = init_db(in_memory=False)
|
||||
|
||||
# Check that the _is_in_memory attribute is set to False
|
||||
assert hasattr(db, '_is_in_memory')
|
||||
assert db._is_in_memory is False
|
||||
|
||||
# Reset the contextvar
|
||||
db_var.set(None)
|
||||
|
||||
# Initialize the database with in_memory=True
|
||||
db = init_db(in_memory=True)
|
||||
|
||||
# Check that the _is_in_memory attribute is set to True
|
||||
assert hasattr(db, '_is_in_memory')
|
||||
assert db._is_in_memory is True
|
||||
"""
|
||||
Tests for the database connection module.
|
||||
"""
|
||||
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
import pytest
|
||||
import peewee
|
||||
|
||||
from ra_aid.database.connection import (
|
||||
init_db, get_db, close_db,
|
||||
db_var, DatabaseManager, logger
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def cleanup_db():
|
||||
"""
|
||||
Fixture to clean up database connections and files between tests.
|
||||
|
||||
This fixture:
|
||||
1. Closes any open database connection
|
||||
2. Resets the contextvar
|
||||
3. Cleans up the .ra-aid directory
|
||||
"""
|
||||
# Store the current working directory
|
||||
original_cwd = os.getcwd()
|
||||
|
||||
# Run the test
|
||||
yield
|
||||
|
||||
# Clean up after the test
|
||||
try:
|
||||
# Close any open database connection
|
||||
close_db()
|
||||
|
||||
# Reset the contextvar
|
||||
db_var.set(None)
|
||||
|
||||
# Clean up the .ra-aid directory if it exists
|
||||
ra_aid_dir = Path(os.getcwd()) / ".ra-aid"
|
||||
ra_aid_dir_str = str(ra_aid_dir.absolute())
|
||||
|
||||
# Check using both methods
|
||||
path_exists = ra_aid_dir.exists()
|
||||
os_exists = os.path.exists(ra_aid_dir_str)
|
||||
|
||||
print(f"Cleanup check: Path.exists={path_exists}, os.path.exists={os_exists}")
|
||||
|
||||
if os_exists:
|
||||
# Only remove the database file, not the entire directory
|
||||
db_file = os.path.join(ra_aid_dir_str, "pk.db")
|
||||
if os.path.exists(db_file):
|
||||
os.unlink(db_file)
|
||||
|
||||
# Remove WAL and SHM files if they exist
|
||||
wal_file = os.path.join(ra_aid_dir_str, "pk.db-wal")
|
||||
if os.path.exists(wal_file):
|
||||
os.unlink(wal_file)
|
||||
|
||||
shm_file = os.path.join(ra_aid_dir_str, "pk.db-shm")
|
||||
if os.path.exists(shm_file):
|
||||
os.unlink(shm_file)
|
||||
|
||||
# List remaining contents for debugging
|
||||
if os.path.exists(ra_aid_dir_str):
|
||||
print(f"Directory contents after cleanup: {os.listdir(ra_aid_dir_str)}")
|
||||
except Exception as e:
|
||||
# Log but don't fail if cleanup has issues
|
||||
print(f"Cleanup error (non-fatal): {str(e)}")
|
||||
|
||||
# Make sure we're back in the original directory
|
||||
os.chdir(original_cwd)
|
||||
|
||||
|
||||
class TestInitDb:
|
||||
"""Tests for the init_db function."""
|
||||
|
||||
def test_init_db_default(self, cleanup_db):
|
||||
"""Test init_db with default parameters."""
|
||||
# Get the absolute path of the current working directory
|
||||
cwd = os.getcwd()
|
||||
print(f"Current working directory: {cwd}")
|
||||
|
||||
# Initialize the database
|
||||
db = init_db()
|
||||
|
||||
assert isinstance(db, peewee.SqliteDatabase)
|
||||
assert not db.is_closed()
|
||||
assert hasattr(db, '_is_in_memory')
|
||||
assert db._is_in_memory is False
|
||||
|
||||
# Verify the database file was created using both Path and os.path methods
|
||||
ra_aid_dir = Path(cwd) / ".ra-aid"
|
||||
ra_aid_dir_str = str(ra_aid_dir.absolute())
|
||||
|
||||
# Check directory existence using both methods
|
||||
path_exists = ra_aid_dir.exists()
|
||||
os_exists = os.path.exists(ra_aid_dir_str)
|
||||
print(f"Directory check: Path.exists={path_exists}, os.path.exists={os_exists}")
|
||||
|
||||
# List the contents of the current directory
|
||||
print(f"Contents of {cwd}: {os.listdir(cwd)}")
|
||||
|
||||
# If the directory exists, list its contents
|
||||
if os_exists:
|
||||
print(f"Contents of {ra_aid_dir_str}: {os.listdir(ra_aid_dir_str)}")
|
||||
|
||||
# Use os.path for assertions to be more reliable
|
||||
assert os.path.exists(ra_aid_dir_str), f"Directory {ra_aid_dir_str} does not exist"
|
||||
assert os.path.isdir(ra_aid_dir_str), f"{ra_aid_dir_str} is not a directory"
|
||||
|
||||
db_file = os.path.join(ra_aid_dir_str, "pk.db")
|
||||
assert os.path.exists(db_file), f"Database file {db_file} does not exist"
|
||||
assert os.path.isfile(db_file), f"{db_file} is not a file"
|
||||
|
||||
def test_init_db_in_memory(self, cleanup_db):
|
||||
"""Test init_db with in_memory=True."""
|
||||
db = init_db(in_memory=True)
|
||||
|
||||
assert isinstance(db, peewee.SqliteDatabase)
|
||||
assert not db.is_closed()
|
||||
assert hasattr(db, '_is_in_memory')
|
||||
assert db._is_in_memory is True
|
||||
|
||||
def test_init_db_reuses_connection(self, cleanup_db):
|
||||
"""Test that init_db reuses an existing connection."""
|
||||
db1 = init_db()
|
||||
db2 = init_db()
|
||||
|
||||
assert db1 is db2
|
||||
|
||||
def test_init_db_reopens_closed_connection(self, cleanup_db):
|
||||
"""Test that init_db reopens a closed connection."""
|
||||
db1 = init_db()
|
||||
db1.close()
|
||||
assert db1.is_closed()
|
||||
|
||||
db2 = init_db()
|
||||
assert db1 is db2
|
||||
assert not db1.is_closed()
|
||||
|
||||
|
||||
class TestGetDb:
|
||||
"""Tests for the get_db function."""
|
||||
|
||||
def test_get_db_creates_connection(self, cleanup_db):
|
||||
"""Test that get_db creates a new connection if none exists."""
|
||||
# Reset the contextvar to ensure no connection exists
|
||||
db_var.set(None)
|
||||
|
||||
db = get_db()
|
||||
|
||||
assert isinstance(db, peewee.SqliteDatabase)
|
||||
assert not db.is_closed()
|
||||
assert hasattr(db, '_is_in_memory')
|
||||
assert db._is_in_memory is False
|
||||
|
||||
def test_get_db_reuses_connection(self, cleanup_db):
|
||||
"""Test that get_db reuses an existing connection."""
|
||||
db1 = init_db()
|
||||
db2 = get_db()
|
||||
|
||||
assert db1 is db2
|
||||
|
||||
def test_get_db_reopens_closed_connection(self, cleanup_db):
|
||||
"""Test that get_db reopens a closed connection."""
|
||||
db1 = init_db()
|
||||
db1.close()
|
||||
assert db1.is_closed()
|
||||
|
||||
db2 = get_db()
|
||||
assert db1 is db2
|
||||
assert not db1.is_closed()
|
||||
|
||||
|
||||
class TestCloseDb:
|
||||
"""Tests for the close_db function."""
|
||||
|
||||
def test_close_db(self, cleanup_db):
|
||||
"""Test that close_db closes an open connection."""
|
||||
db = init_db()
|
||||
assert not db.is_closed()
|
||||
|
||||
close_db()
|
||||
assert db.is_closed()
|
||||
|
||||
def test_close_db_no_connection(self, cleanup_db):
|
||||
"""Test that close_db handles the case where no connection exists."""
|
||||
# Reset the contextvar to ensure no connection exists
|
||||
db_var.set(None)
|
||||
|
||||
# This should not raise an exception
|
||||
close_db()
|
||||
|
||||
def test_close_db_already_closed(self, cleanup_db):
|
||||
"""Test that close_db handles the case where the connection is already closed."""
|
||||
db = init_db()
|
||||
db.close()
|
||||
assert db.is_closed()
|
||||
|
||||
# This should not raise an exception
|
||||
close_db()
|
||||
|
||||
|
||||
class TestDatabaseManager:
|
||||
"""Tests for the DatabaseManager class."""
|
||||
|
||||
def test_database_manager_default(self, cleanup_db):
|
||||
"""Test DatabaseManager with default parameters."""
|
||||
with DatabaseManager() as db:
|
||||
assert isinstance(db, peewee.SqliteDatabase)
|
||||
assert not db.is_closed()
|
||||
assert hasattr(db, '_is_in_memory')
|
||||
assert db._is_in_memory is False
|
||||
|
||||
# Verify the database file was created
|
||||
ra_aid_dir = Path(os.getcwd()) / ".ra-aid"
|
||||
assert ra_aid_dir.exists()
|
||||
assert (ra_aid_dir / "pk.db").exists()
|
||||
|
||||
# Verify the connection is closed after exiting the context
|
||||
assert db.is_closed()
|
||||
|
||||
def test_database_manager_in_memory(self, cleanup_db):
|
||||
"""Test DatabaseManager with in_memory=True."""
|
||||
with DatabaseManager(in_memory=True) as db:
|
||||
assert isinstance(db, peewee.SqliteDatabase)
|
||||
assert not db.is_closed()
|
||||
assert hasattr(db, '_is_in_memory')
|
||||
assert db._is_in_memory is True
|
||||
|
||||
# Verify the connection is closed after exiting the context
|
||||
assert db.is_closed()
|
||||
|
||||
def test_database_manager_exception_handling(self, cleanup_db):
|
||||
"""Test that DatabaseManager properly handles exceptions."""
|
||||
try:
|
||||
with DatabaseManager() as db:
|
||||
assert not db.is_closed()
|
||||
raise ValueError("Test exception")
|
||||
except ValueError:
|
||||
# The exception should be propagated
|
||||
pass
|
||||
|
||||
# Verify the connection is closed even if an exception occurred
|
||||
assert db.is_closed()
|
||||
|
|
@ -0,0 +1,94 @@
|
|||
"""
|
||||
Database utility functions for ra_aid.
|
||||
|
||||
This module provides utility functions for common database operations.
|
||||
"""
|
||||
|
||||
import importlib
|
||||
import inspect
|
||||
from typing import List, Type
|
||||
|
||||
import peewee
|
||||
|
||||
from ra_aid.database.connection import get_db
|
||||
from ra_aid.database.models import BaseModel
|
||||
from ra_aid.logging_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
def ensure_tables_created(models: List[Type[BaseModel]] = None) -> None:
|
||||
"""
|
||||
Ensure that database tables for the specified models exist.
|
||||
|
||||
If no models are specified, this function will attempt to discover
|
||||
all models that inherit from BaseModel.
|
||||
|
||||
Args:
|
||||
models: Optional list of model classes to create tables for
|
||||
"""
|
||||
db = get_db()
|
||||
|
||||
if models is None:
|
||||
# If no models are specified, try to discover them
|
||||
models = []
|
||||
try:
|
||||
# Import all modules that might contain models
|
||||
# This is a placeholder - in a real implementation, you would
|
||||
# dynamically discover and import all modules that might contain models
|
||||
from ra_aid.database import models as models_module
|
||||
|
||||
# Find all classes in the module that inherit from BaseModel
|
||||
for name, obj in inspect.getmembers(models_module):
|
||||
if (inspect.isclass(obj) and
|
||||
issubclass(obj, BaseModel) and
|
||||
obj != BaseModel):
|
||||
models.append(obj)
|
||||
except ImportError as e:
|
||||
logger.warning(f"Error importing model modules: {e}")
|
||||
|
||||
if not models:
|
||||
logger.warning("No models found to create tables for")
|
||||
return
|
||||
|
||||
try:
|
||||
with db.atomic():
|
||||
db.create_tables(models, safe=True)
|
||||
logger.info(f"Successfully created tables for {len(models)} models")
|
||||
except peewee.DatabaseError as e:
|
||||
logger.error(f"Database Error: Failed to create tables: {str(e)}")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error: Failed to create tables: {str(e)}")
|
||||
raise
|
||||
|
||||
def get_model_count(model_class: Type[BaseModel]) -> int:
|
||||
"""
|
||||
Get the count of records for a specific model.
|
||||
|
||||
Args:
|
||||
model_class: The model class to count records for
|
||||
|
||||
Returns:
|
||||
int: The number of records for the model
|
||||
"""
|
||||
try:
|
||||
return model_class.select().count()
|
||||
except peewee.DatabaseError as e:
|
||||
logger.error(f"Database Error: Failed to count records: {str(e)}")
|
||||
return 0
|
||||
|
||||
def truncate_table(model_class: Type[BaseModel]) -> None:
|
||||
"""
|
||||
Delete all records from a model's table.
|
||||
|
||||
Args:
|
||||
model_class: The model class to truncate
|
||||
"""
|
||||
db = get_db()
|
||||
try:
|
||||
with db.atomic():
|
||||
model_class.delete().execute()
|
||||
logger.info(f"Successfully truncated table for {model_class.__name__}")
|
||||
except peewee.DatabaseError as e:
|
||||
logger.error(f"Database Error: Failed to truncate table: {str(e)}")
|
||||
raise
|
||||
|
|
@ -0,0 +1 @@
|
|||
# This file is intentionally left empty to make the directory a Python package
|
||||
|
|
@ -0,0 +1 @@
|
|||
# This file is intentionally left empty to make the directory a Python package
|
||||
|
|
@ -0,0 +1 @@
|
|||
# This file is intentionally left empty to make the directory a Python package
|
||||
|
|
@ -0,0 +1,181 @@
|
|||
"""
|
||||
Tests for the database connection module.
|
||||
"""
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
import pytest
|
||||
import peewee
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
from ra_aid.database.connection import (
|
||||
init_db, get_db, close_db,
|
||||
db_var, DatabaseManager, logger
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def cleanup_db():
|
||||
"""
|
||||
Fixture to clean up database connections and files between tests.
|
||||
This fixture:
|
||||
1. Closes any open database connection
|
||||
2. Resets the contextvar
|
||||
3. Cleans up the .ra-aid directory
|
||||
"""
|
||||
# Run the test
|
||||
yield
|
||||
# Clean up after the test
|
||||
try:
|
||||
# Close any open database connection
|
||||
close_db()
|
||||
# Reset the contextvar
|
||||
db_var.set(None)
|
||||
# Clean up the .ra-aid directory if it exists
|
||||
ra_aid_dir = Path(os.getcwd()) / ".ra-aid"
|
||||
if ra_aid_dir.exists():
|
||||
# Only remove the database file, not the entire directory
|
||||
db_file = ra_aid_dir / "pk.db"
|
||||
if db_file.exists():
|
||||
db_file.unlink()
|
||||
# Remove WAL and SHM files if they exist
|
||||
wal_file = ra_aid_dir / "pk.db-wal"
|
||||
if wal_file.exists():
|
||||
wal_file.unlink()
|
||||
shm_file = ra_aid_dir / "pk.db-shm"
|
||||
if shm_file.exists():
|
||||
shm_file.unlink()
|
||||
except Exception as e:
|
||||
# Log but don't fail if cleanup has issues
|
||||
print(f"Cleanup error (non-fatal): {str(e)}")
|
||||
|
||||
@pytest.fixture
|
||||
def mock_logger():
|
||||
"""Mock the logger to test for output messages."""
|
||||
with patch('ra_aid.database.connection.logger') as mock:
|
||||
yield mock
|
||||
|
||||
class TestInitDb:
|
||||
"""Tests for the init_db function."""
|
||||
def test_init_db_default(self, cleanup_db):
|
||||
"""Test init_db with default parameters."""
|
||||
db = init_db()
|
||||
assert isinstance(db, peewee.SqliteDatabase)
|
||||
assert not db.is_closed()
|
||||
assert hasattr(db, '_is_in_memory')
|
||||
assert db._is_in_memory is False
|
||||
# Verify the database file was created
|
||||
ra_aid_dir = Path(os.getcwd()) / ".ra-aid"
|
||||
assert ra_aid_dir.exists()
|
||||
assert (ra_aid_dir / "pk.db").exists()
|
||||
|
||||
def test_init_db_in_memory(self, cleanup_db):
|
||||
"""Test init_db with in_memory=True."""
|
||||
db = init_db(in_memory=True)
|
||||
assert isinstance(db, peewee.SqliteDatabase)
|
||||
assert not db.is_closed()
|
||||
assert hasattr(db, '_is_in_memory')
|
||||
assert db._is_in_memory is True
|
||||
|
||||
def test_init_db_reuses_connection(self, cleanup_db):
|
||||
"""Test that init_db reuses an existing connection."""
|
||||
db1 = init_db()
|
||||
db2 = init_db()
|
||||
assert db1 is db2
|
||||
|
||||
def test_init_db_reopens_closed_connection(self, cleanup_db):
|
||||
"""Test that init_db reopens a closed connection."""
|
||||
db1 = init_db()
|
||||
db1.close()
|
||||
assert db1.is_closed()
|
||||
db2 = init_db()
|
||||
assert db1 is db2
|
||||
assert not db1.is_closed()
|
||||
|
||||
class TestGetDb:
|
||||
"""Tests for the get_db function."""
|
||||
def test_get_db_creates_connection(self, cleanup_db):
|
||||
"""Test that get_db creates a new connection if none exists."""
|
||||
# Reset the contextvar to ensure no connection exists
|
||||
db_var.set(None)
|
||||
db = get_db()
|
||||
assert isinstance(db, peewee.SqliteDatabase)
|
||||
assert not db.is_closed()
|
||||
assert hasattr(db, '_is_in_memory')
|
||||
assert db._is_in_memory is False
|
||||
|
||||
def test_get_db_reuses_connection(self, cleanup_db):
|
||||
"""Test that get_db reuses an existing connection."""
|
||||
db1 = init_db()
|
||||
db2 = get_db()
|
||||
assert db1 is db2
|
||||
|
||||
def test_get_db_reopens_closed_connection(self, cleanup_db):
|
||||
"""Test that get_db reopens a closed connection."""
|
||||
db1 = init_db()
|
||||
db1.close()
|
||||
assert db1.is_closed()
|
||||
db2 = get_db()
|
||||
assert db1 is db2
|
||||
assert not db1.is_closed()
|
||||
|
||||
class TestCloseDb:
|
||||
"""Tests for the close_db function."""
|
||||
def test_close_db(self, cleanup_db):
|
||||
"""Test that close_db closes an open connection."""
|
||||
db = init_db()
|
||||
assert not db.is_closed()
|
||||
close_db()
|
||||
assert db.is_closed()
|
||||
|
||||
def test_close_db_no_connection(self, cleanup_db):
|
||||
"""Test that close_db handles the case where no connection exists."""
|
||||
# Reset the contextvar to ensure no connection exists
|
||||
db_var.set(None)
|
||||
# This should not raise an exception
|
||||
close_db()
|
||||
|
||||
def test_close_db_already_closed(self, cleanup_db):
|
||||
"""Test that close_db handles the case where the connection is already closed."""
|
||||
db = init_db()
|
||||
db.close()
|
||||
assert db.is_closed()
|
||||
# This should not raise an exception
|
||||
close_db()
|
||||
|
||||
class TestDatabaseManager:
|
||||
"""Tests for the DatabaseManager class."""
|
||||
def test_database_manager_default(self, cleanup_db):
|
||||
"""Test DatabaseManager with default parameters."""
|
||||
with DatabaseManager() as db:
|
||||
assert isinstance(db, peewee.SqliteDatabase)
|
||||
assert not db.is_closed()
|
||||
assert hasattr(db, '_is_in_memory')
|
||||
assert db._is_in_memory is False
|
||||
# Verify the database file was created
|
||||
ra_aid_dir = Path(os.getcwd()) / ".ra-aid"
|
||||
assert ra_aid_dir.exists()
|
||||
assert (ra_aid_dir / "pk.db").exists()
|
||||
# Verify the connection is closed after exiting the context
|
||||
assert db.is_closed()
|
||||
|
||||
def test_database_manager_in_memory(self, cleanup_db):
|
||||
"""Test DatabaseManager with in_memory=True."""
|
||||
with DatabaseManager(in_memory=True) as db:
|
||||
assert isinstance(db, peewee.SqliteDatabase)
|
||||
assert not db.is_closed()
|
||||
assert hasattr(db, '_is_in_memory')
|
||||
assert db._is_in_memory is True
|
||||
# Verify the connection is closed after exiting the context
|
||||
assert db.is_closed()
|
||||
|
||||
def test_database_manager_exception_handling(self, cleanup_db):
|
||||
"""Test that DatabaseManager properly handles exceptions."""
|
||||
try:
|
||||
with DatabaseManager() as db:
|
||||
assert not db.is_closed()
|
||||
raise ValueError("Test exception")
|
||||
except ValueError:
|
||||
# The exception should be propagated
|
||||
pass
|
||||
# Verify the connection is closed even if an exception occurred
|
||||
assert db.is_closed()
|
||||
|
|
@ -0,0 +1,132 @@
|
|||
"""
|
||||
Tests for the database models module.
|
||||
"""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import peewee
|
||||
|
||||
from ra_aid.database.models import BaseModel
|
||||
from ra_aid.database.connection import (
|
||||
db_var, get_db, init_db, close_db
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def cleanup_db():
|
||||
"""Reset the database contextvar and connection state after each test."""
|
||||
# Reset before the test
|
||||
db = db_var.get()
|
||||
if db is not None:
|
||||
try:
|
||||
if not db.is_closed():
|
||||
db.close()
|
||||
except Exception:
|
||||
# Ignore errors when closing the database
|
||||
pass
|
||||
db_var.set(None)
|
||||
|
||||
# Run the test
|
||||
yield
|
||||
|
||||
# Reset after the test
|
||||
db = db_var.get()
|
||||
if db is not None:
|
||||
try:
|
||||
if not db.is_closed():
|
||||
db.close()
|
||||
except Exception:
|
||||
# Ignore errors when closing the database
|
||||
pass
|
||||
db_var.set(None)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def setup_test_model(cleanup_db):
|
||||
"""Set up a test model class for testing."""
|
||||
# Initialize an in-memory database connection
|
||||
db = init_db(in_memory=True)
|
||||
|
||||
# Define a test model
|
||||
class TestModel(BaseModel):
|
||||
name = peewee.CharField()
|
||||
value = peewee.IntegerField(null=True)
|
||||
|
||||
class Meta:
|
||||
database = db
|
||||
|
||||
# Create the table
|
||||
with db.atomic():
|
||||
db.create_tables([TestModel], safe=True)
|
||||
|
||||
yield TestModel
|
||||
|
||||
# Clean up
|
||||
with db.atomic():
|
||||
TestModel.drop_table(safe=True)
|
||||
|
||||
|
||||
def test_base_model_save_updates_timestamps(setup_test_model):
|
||||
"""Test that saving a model updates the timestamps."""
|
||||
TestModel = setup_test_model
|
||||
|
||||
# Create a new instance
|
||||
instance = TestModel(name="test", value=42)
|
||||
instance.save()
|
||||
|
||||
# Check that created_at and updated_at are set
|
||||
assert instance.created_at is not None
|
||||
assert instance.updated_at is not None
|
||||
|
||||
# Store the original timestamps
|
||||
original_created_at = instance.created_at
|
||||
original_updated_at = instance.updated_at
|
||||
|
||||
# Wait a moment to ensure timestamps would be different
|
||||
import time
|
||||
time.sleep(0.001)
|
||||
|
||||
# Update the instance
|
||||
instance.value = 43
|
||||
instance.save()
|
||||
|
||||
# Check that updated_at changed but created_at didn't
|
||||
assert instance.created_at == original_created_at
|
||||
assert instance.updated_at != original_updated_at
|
||||
|
||||
|
||||
def test_base_model_get_or_create(setup_test_model):
|
||||
"""Test the get_or_create method."""
|
||||
TestModel = setup_test_model
|
||||
|
||||
# First call should create a new instance
|
||||
instance, created = TestModel.get_or_create(name="test", value=42)
|
||||
assert created is True
|
||||
assert instance.name == "test"
|
||||
assert instance.value == 42
|
||||
|
||||
# Second call with same parameters should return existing instance
|
||||
instance2, created2 = TestModel.get_or_create(name="test", value=42)
|
||||
assert created2 is False
|
||||
assert instance2.id == instance.id
|
||||
|
||||
# Call with different parameters should create a new instance
|
||||
instance3, created3 = TestModel.get_or_create(name="test2", value=43)
|
||||
assert created3 is True
|
||||
assert instance3.id != instance.id
|
||||
|
||||
|
||||
@patch('ra_aid.database.models.logger')
|
||||
def test_base_model_get_or_create_handles_errors(mock_logger, setup_test_model):
|
||||
"""Test that get_or_create handles database errors properly."""
|
||||
TestModel = setup_test_model
|
||||
|
||||
# Mock the parent get_or_create to raise a DatabaseError
|
||||
with patch('peewee.Model.get_or_create', side_effect=peewee.DatabaseError("Test error")):
|
||||
# Call should raise the error
|
||||
with pytest.raises(peewee.DatabaseError):
|
||||
TestModel.get_or_create(name="test")
|
||||
|
||||
# Verify error was logged
|
||||
mock_logger.error.assert_called_with("Failed in get_or_create: Test error")
|
||||
|
|
@ -0,0 +1,220 @@
|
|||
"""
|
||||
Tests for the database utils module.
|
||||
"""
|
||||
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
import pytest
|
||||
import peewee
|
||||
|
||||
from ra_aid.database.connection import (
|
||||
db_var, get_db, init_db, close_db
|
||||
)
|
||||
from ra_aid.database.models import BaseModel
|
||||
from ra_aid.database.utils import ensure_tables_created, get_model_count, truncate_table
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def cleanup_db():
|
||||
"""Reset the database contextvar and connection state after each test."""
|
||||
# Reset before the test
|
||||
db = db_var.get()
|
||||
if db is not None:
|
||||
try:
|
||||
if not db.is_closed():
|
||||
db.close()
|
||||
except Exception:
|
||||
# Ignore errors when closing the database
|
||||
pass
|
||||
db_var.set(None)
|
||||
|
||||
# Run the test
|
||||
yield
|
||||
|
||||
# Reset after the test
|
||||
db = db_var.get()
|
||||
if db is not None:
|
||||
try:
|
||||
if not db.is_closed():
|
||||
db.close()
|
||||
except Exception:
|
||||
# Ignore errors when closing the database
|
||||
pass
|
||||
db_var.set(None)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_logger():
|
||||
"""Mock the logger to test for output messages."""
|
||||
with patch('ra_aid.database.utils.logger') as mock:
|
||||
yield mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def setup_test_model(cleanup_db):
|
||||
"""Set up a test model for database tests."""
|
||||
# Initialize the database in memory
|
||||
db = init_db(in_memory=True)
|
||||
|
||||
# Define a test model class
|
||||
class TestModel(BaseModel):
|
||||
name = peewee.CharField(max_length=100)
|
||||
value = peewee.IntegerField(default=0)
|
||||
|
||||
class Meta:
|
||||
database = db
|
||||
|
||||
# Create the test table in a transaction
|
||||
with db.atomic():
|
||||
db.create_tables([TestModel], safe=True)
|
||||
|
||||
# Yield control to the test
|
||||
yield TestModel
|
||||
|
||||
# Clean up: drop the test table
|
||||
with db.atomic():
|
||||
db.drop_tables([TestModel], safe=True)
|
||||
|
||||
|
||||
def test_ensure_tables_created_with_models(cleanup_db, mock_logger):
|
||||
"""Test ensure_tables_created with explicit models."""
|
||||
# Initialize the database in memory
|
||||
db = init_db(in_memory=True)
|
||||
|
||||
# Define a test model that uses this database
|
||||
class TestModel(BaseModel):
|
||||
name = peewee.CharField(max_length=100)
|
||||
value = peewee.IntegerField(default=0)
|
||||
|
||||
class Meta:
|
||||
database = db
|
||||
|
||||
# Call ensure_tables_created with explicit models
|
||||
ensure_tables_created([TestModel])
|
||||
|
||||
# Verify success message was logged
|
||||
mock_logger.info.assert_called_with("Successfully created tables for 1 models")
|
||||
|
||||
# Verify the table exists by trying to use it
|
||||
TestModel.create(name="test", value=42)
|
||||
count = TestModel.select().count()
|
||||
assert count == 1
|
||||
|
||||
|
||||
def test_ensure_tables_created_no_models(cleanup_db, mock_logger):
|
||||
"""Test ensure_tables_created with no models."""
|
||||
# Initialize the database in memory
|
||||
db = init_db(in_memory=True)
|
||||
|
||||
# Mock the import to simulate no models found
|
||||
with patch('ra_aid.database.utils.importlib.import_module', side_effect=ImportError("No module")):
|
||||
# Call ensure_tables_created with no models
|
||||
ensure_tables_created()
|
||||
|
||||
# Verify warning message was logged
|
||||
mock_logger.warning.assert_called_with("No models found to create tables for")
|
||||
|
||||
|
||||
@patch('ra_aid.database.utils.get_db')
|
||||
def test_ensure_tables_created_database_error(mock_get_db, setup_test_model, cleanup_db, mock_logger):
|
||||
"""Test ensure_tables_created handles database errors."""
|
||||
# Get the TestModel class from the fixture
|
||||
TestModel = setup_test_model
|
||||
|
||||
# Create a mock database with a create_tables method that raises an error
|
||||
mock_db = MagicMock()
|
||||
mock_db.atomic.return_value.__enter__.return_value = None
|
||||
mock_db.atomic.return_value.__exit__.return_value = None
|
||||
mock_db.create_tables.side_effect = peewee.DatabaseError("Test database error")
|
||||
|
||||
# Configure get_db to return our mock
|
||||
mock_get_db.return_value = mock_db
|
||||
|
||||
# Call ensure_tables_created and expect an exception
|
||||
with pytest.raises(peewee.DatabaseError):
|
||||
ensure_tables_created([TestModel])
|
||||
|
||||
# Verify error message was logged
|
||||
mock_logger.error.assert_called_with("Database Error: Failed to create tables: Test database error")
|
||||
|
||||
|
||||
def test_get_model_count(setup_test_model, mock_logger):
|
||||
"""Test get_model_count returns the correct count."""
|
||||
# Get the TestModel class from the fixture
|
||||
TestModel = setup_test_model
|
||||
|
||||
# First ensure the table is empty
|
||||
TestModel.delete().execute()
|
||||
|
||||
# Create some test records
|
||||
TestModel.create(name="test1", value=1)
|
||||
TestModel.create(name="test2", value=2)
|
||||
|
||||
# Call get_model_count
|
||||
count = get_model_count(TestModel)
|
||||
|
||||
# Verify the count is correct
|
||||
assert count == 2
|
||||
|
||||
|
||||
@patch('peewee.ModelSelect.count')
|
||||
def test_get_model_count_database_error(mock_count, setup_test_model, mock_logger):
|
||||
"""Test get_model_count handles database errors."""
|
||||
# Get the TestModel class from the fixture
|
||||
TestModel = setup_test_model
|
||||
|
||||
# Configure the mock to raise a DatabaseError
|
||||
mock_count.side_effect = peewee.DatabaseError("Test count error")
|
||||
|
||||
# Call get_model_count
|
||||
count = get_model_count(TestModel)
|
||||
|
||||
# Verify error message was logged
|
||||
mock_logger.error.assert_called_with("Database Error: Failed to count records: Test count error")
|
||||
|
||||
# Verify the function returns 0 on error
|
||||
assert count == 0
|
||||
|
||||
|
||||
def test_truncate_table(setup_test_model, mock_logger):
|
||||
"""Test truncate_table deletes all records."""
|
||||
# Get the TestModel class from the fixture
|
||||
TestModel = setup_test_model
|
||||
|
||||
# Create some test records
|
||||
TestModel.create(name="test1", value=1)
|
||||
TestModel.create(name="test2", value=2)
|
||||
|
||||
# Verify records exist
|
||||
assert TestModel.select().count() == 2
|
||||
|
||||
# Call truncate_table
|
||||
truncate_table(TestModel)
|
||||
|
||||
# Verify success message was logged
|
||||
mock_logger.info.assert_called_with(f"Successfully truncated table for {TestModel.__name__}")
|
||||
|
||||
# Verify all records were deleted
|
||||
assert TestModel.select().count() == 0
|
||||
|
||||
|
||||
@patch('ra_aid.database.models.BaseModel.delete')
|
||||
def test_truncate_table_database_error(mock_delete, setup_test_model, mock_logger):
|
||||
"""Test truncate_table handles database errors."""
|
||||
# Get the TestModel class from the fixture
|
||||
TestModel = setup_test_model
|
||||
|
||||
# Create a test record
|
||||
TestModel.create(name="test", value=42)
|
||||
|
||||
# Configure the mock to return a mock query with execute that raises an error
|
||||
mock_query = MagicMock()
|
||||
mock_query.execute.side_effect = peewee.DatabaseError("Test delete error")
|
||||
mock_delete.return_value = mock_query
|
||||
|
||||
# Call truncate_table and expect an exception
|
||||
with pytest.raises(peewee.DatabaseError):
|
||||
truncate_table(TestModel)
|
||||
|
||||
# Verify error message was logged
|
||||
mock_logger.error.assert_called_with("Database Error: Failed to truncate table: Test delete error")
|
||||
Loading…
Reference in New Issue