key snippets db
This commit is contained in:
parent
be2eb298a5
commit
038e7b886c
|
|
@ -610,6 +610,9 @@ def main():
|
|||
memory=planning_memory,
|
||||
config=config,
|
||||
)
|
||||
|
||||
# Run cleanup tasks before exiting database context
|
||||
run_cleanup()
|
||||
|
||||
except (KeyboardInterrupt, AgentInterrupt):
|
||||
print()
|
||||
|
|
@ -618,5 +621,17 @@ def main():
|
|||
sys.exit(0)
|
||||
|
||||
|
||||
def run_cleanup():
|
||||
"""Run cleanup tasks after main execution."""
|
||||
try:
|
||||
# Import the key facts cleaner agent
|
||||
from ra_aid.agents.key_facts_gc_agent import run_key_facts_gc_agent
|
||||
|
||||
# Run the key facts garbage collection agent regardless of the number of facts
|
||||
run_key_facts_gc_agent()
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to run cleanup tasks: {str(e)}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -419,7 +419,7 @@ def run_research_agent(
|
|||
project_info=formatted_project_info,
|
||||
new_project_hints=NEW_PROJECT_HINTS if project_info.is_new else "",
|
||||
)
|
||||
|
||||
|
||||
config = _global_memory.get("config", {}) if not config else config
|
||||
recursion_limit = config.get("recursion_limit", DEFAULT_RECURSION_LIMIT)
|
||||
run_config = {
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ from ra_aid.database.migrations import (
|
|||
get_migration_status,
|
||||
init_migrations,
|
||||
)
|
||||
from ra_aid.database.models import BaseModel
|
||||
from ra_aid.database.models import BaseModel, initialize_database
|
||||
from ra_aid.database.utils import ensure_tables_created, get_model_count, truncate_table
|
||||
|
||||
__all__ = [
|
||||
|
|
@ -22,6 +22,7 @@ __all__ = [
|
|||
"close_db",
|
||||
"DatabaseManager",
|
||||
"BaseModel",
|
||||
"initialize_database",
|
||||
"get_model_count",
|
||||
"truncate_table",
|
||||
"ensure_tables_created",
|
||||
|
|
|
|||
|
|
@ -14,6 +14,9 @@ import peewee
|
|||
|
||||
from ra_aid.logging_config import get_logger
|
||||
|
||||
# Import initialize_database after it's defined in models.py
|
||||
# We need to do the import inside functions to avoid circular imports
|
||||
|
||||
# Create contextvar to hold the database connection
|
||||
db_var = contextvars.ContextVar("db", default=None)
|
||||
logger = get_logger(__name__)
|
||||
|
|
@ -34,16 +37,23 @@ class DatabaseManager:
|
|||
# Or with in-memory database:
|
||||
with DatabaseManager(in_memory=True) as db:
|
||||
# Use in-memory database
|
||||
|
||||
# Or with custom base directory:
|
||||
with DatabaseManager(base_dir="/custom/path") as db:
|
||||
# Use database in custom directory
|
||||
"""
|
||||
|
||||
def __init__(self, in_memory: bool = False):
|
||||
def __init__(self, in_memory: bool = False, base_dir: Optional[str] = None):
|
||||
"""
|
||||
Initialize the DatabaseManager.
|
||||
|
||||
Args:
|
||||
in_memory: Whether to use an in-memory database (default: False)
|
||||
base_dir: Optional base directory to use instead of current working directory.
|
||||
If None, uses os.getcwd() (default: None)
|
||||
"""
|
||||
self.in_memory = in_memory
|
||||
self.base_dir = base_dir
|
||||
|
||||
def __enter__(self) -> peewee.SqliteDatabase:
|
||||
"""
|
||||
|
|
@ -52,7 +62,19 @@ class DatabaseManager:
|
|||
Returns:
|
||||
peewee.SqliteDatabase: The initialized database connection
|
||||
"""
|
||||
return init_db(in_memory=self.in_memory)
|
||||
db = init_db(in_memory=self.in_memory, base_dir=self.base_dir)
|
||||
|
||||
# Initialize the database proxy in models.py
|
||||
try:
|
||||
# Import here to avoid circular imports
|
||||
from ra_aid.database.models import initialize_database
|
||||
initialize_database()
|
||||
except ImportError as e:
|
||||
logger.error(f"Failed to import initialize_database: {str(e)}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error initializing database proxy: {str(e)}")
|
||||
|
||||
return db
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
|
|
@ -74,7 +96,7 @@ class DatabaseManager:
|
|||
return False
|
||||
|
||||
|
||||
def init_db(in_memory: bool = False) -> peewee.SqliteDatabase:
|
||||
def init_db(in_memory: bool = False, base_dir: Optional[str] = None) -> peewee.SqliteDatabase:
|
||||
"""
|
||||
Initialize the database connection.
|
||||
|
||||
|
|
@ -84,6 +106,8 @@ def init_db(in_memory: bool = False) -> peewee.SqliteDatabase:
|
|||
|
||||
Args:
|
||||
in_memory: Whether to use an in-memory database (default: False)
|
||||
base_dir: Optional base directory to use instead of current working directory.
|
||||
If None, uses os.getcwd() (default: None)
|
||||
|
||||
Returns:
|
||||
peewee.SqliteDatabase: The initialized database connection
|
||||
|
|
@ -110,9 +134,9 @@ def init_db(in_memory: bool = False) -> peewee.SqliteDatabase:
|
|||
db_path = ":memory:"
|
||||
logger.debug("Using in-memory SQLite database")
|
||||
else:
|
||||
# Get current working directory and create .ra-aid directory if it doesn't exist
|
||||
cwd = os.getcwd()
|
||||
logger.debug(f"Current working directory: {cwd}")
|
||||
# Get base directory (use current working directory if not provided)
|
||||
cwd = base_dir if base_dir is not None else os.getcwd()
|
||||
logger.debug(f"Base directory for database: {cwd}")
|
||||
|
||||
# Define the .ra-aid directory path
|
||||
ra_aid_dir_str = os.path.join(cwd, ".ra-aid")
|
||||
|
|
@ -300,13 +324,17 @@ def init_db(in_memory: bool = False) -> peewee.SqliteDatabase:
|
|||
raise
|
||||
|
||||
|
||||
def get_db() -> peewee.SqliteDatabase:
|
||||
def get_db(base_dir: Optional[str] = None) -> peewee.SqliteDatabase:
|
||||
"""
|
||||
Get the current database connection.
|
||||
|
||||
If no connection exists, initializes a new one.
|
||||
If connection exists but is closed, reopens it.
|
||||
|
||||
Args:
|
||||
base_dir: Optional base directory to use instead of current working directory.
|
||||
If None, uses os.getcwd() (default: None)
|
||||
|
||||
Returns:
|
||||
peewee.SqliteDatabase: The current database connection
|
||||
"""
|
||||
|
|
@ -315,7 +343,7 @@ def get_db() -> peewee.SqliteDatabase:
|
|||
if db is None:
|
||||
# No database connection exists, initialize one
|
||||
# Use the default in-memory mode (False)
|
||||
return init_db(in_memory=False)
|
||||
return init_db(in_memory=False, base_dir=base_dir)
|
||||
|
||||
# Check if connection is closed and reopen if needed
|
||||
if db.is_closed():
|
||||
|
|
@ -332,7 +360,7 @@ def get_db() -> peewee.SqliteDatabase:
|
|||
in_memory = hasattr(db, "_is_in_memory") and db._is_in_memory
|
||||
logger.debug(f"Creating new database connection (in_memory={in_memory})")
|
||||
# Create a completely new database object, don't reuse the old one
|
||||
return init_db(in_memory=in_memory)
|
||||
return init_db(in_memory=in_memory, base_dir=base_dir)
|
||||
|
||||
return db
|
||||
|
||||
|
|
|
|||
|
|
@ -15,6 +15,21 @@ from ra_aid.logging_config import get_logger
|
|||
T = TypeVar("T", bound="BaseModel")
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# Create a database proxy that will be initialized later
|
||||
database_proxy = peewee.DatabaseProxy()
|
||||
|
||||
|
||||
def initialize_database():
|
||||
"""
|
||||
Initialize the database proxy with a real database connection.
|
||||
|
||||
This function should be called before any database operations
|
||||
to ensure the proxy points to a real database connection.
|
||||
"""
|
||||
db = get_db()
|
||||
database_proxy.initialize(db)
|
||||
return db
|
||||
|
||||
|
||||
class BaseModel(peewee.Model):
|
||||
"""
|
||||
|
|
@ -28,7 +43,7 @@ class BaseModel(peewee.Model):
|
|||
updated_at = peewee.DateTimeField(default=datetime.datetime.now)
|
||||
|
||||
class Meta:
|
||||
database = get_db()
|
||||
database = database_proxy
|
||||
|
||||
def save(self, *args: Any, **kwargs: Any) -> int:
|
||||
"""
|
||||
|
|
@ -75,4 +90,22 @@ class KeyFact(BaseModel):
|
|||
# created_at and updated_at are inherited from BaseModel
|
||||
|
||||
class Meta:
|
||||
table_name = "key_fact"
|
||||
table_name = "key_fact"
|
||||
|
||||
|
||||
class KeySnippet(BaseModel):
|
||||
"""
|
||||
Model representing a key code snippet stored in the database.
|
||||
|
||||
Key snippets are important code fragments from the project that need to be
|
||||
referenced later. Each snippet includes its file location, line number,
|
||||
the code content itself, and an optional description of its significance.
|
||||
"""
|
||||
filepath = peewee.TextField()
|
||||
line_number = peewee.IntegerField()
|
||||
snippet = peewee.TextField()
|
||||
description = peewee.TextField(null=True)
|
||||
# created_at and updated_at are inherited from BaseModel
|
||||
|
||||
class Meta:
|
||||
table_name = "key_snippet"
|
||||
|
|
@ -10,7 +10,7 @@ from typing import Dict, List, Optional
|
|||
import peewee
|
||||
|
||||
from ra_aid.database.connection import get_db
|
||||
from ra_aid.database.models import KeyFact
|
||||
from ra_aid.database.models import KeyFact, initialize_database
|
||||
from ra_aid.logging_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
|
@ -29,6 +29,15 @@ class KeyFactRepository:
|
|||
all_facts = repo.get_all()
|
||||
"""
|
||||
|
||||
def __init__(self, db=None):
|
||||
"""
|
||||
Initialize the repository with an optional database connection.
|
||||
|
||||
Args:
|
||||
db: Optional database connection to use. If None, will use initialize_database()
|
||||
"""
|
||||
self.db = db
|
||||
|
||||
def create(self, content: str) -> KeyFact:
|
||||
"""
|
||||
Create a new key fact in the database.
|
||||
|
|
@ -43,7 +52,7 @@ class KeyFactRepository:
|
|||
peewee.DatabaseError: If there's an error creating the fact
|
||||
"""
|
||||
try:
|
||||
db = get_db()
|
||||
db = self.db if self.db is not None else initialize_database()
|
||||
fact = KeyFact.create(content=content)
|
||||
logger.debug(f"Created key fact ID {fact.id}: {content}")
|
||||
return fact
|
||||
|
|
@ -65,7 +74,7 @@ class KeyFactRepository:
|
|||
peewee.DatabaseError: If there's an error accessing the database
|
||||
"""
|
||||
try:
|
||||
db = get_db()
|
||||
db = self.db if self.db is not None else initialize_database()
|
||||
return KeyFact.get_or_none(KeyFact.id == fact_id)
|
||||
except peewee.DatabaseError as e:
|
||||
logger.error(f"Failed to fetch key fact {fact_id}: {str(e)}")
|
||||
|
|
@ -86,7 +95,7 @@ class KeyFactRepository:
|
|||
peewee.DatabaseError: If there's an error updating the fact
|
||||
"""
|
||||
try:
|
||||
db = get_db()
|
||||
db = self.db if self.db is not None else initialize_database()
|
||||
# First check if the fact exists
|
||||
fact = self.get(fact_id)
|
||||
if not fact:
|
||||
|
|
@ -116,7 +125,7 @@ class KeyFactRepository:
|
|||
peewee.DatabaseError: If there's an error deleting the fact
|
||||
"""
|
||||
try:
|
||||
db = get_db()
|
||||
db = self.db if self.db is not None else initialize_database()
|
||||
# First check if the fact exists
|
||||
fact = self.get(fact_id)
|
||||
if not fact:
|
||||
|
|
@ -142,7 +151,7 @@ class KeyFactRepository:
|
|||
peewee.DatabaseError: If there's an error accessing the database
|
||||
"""
|
||||
try:
|
||||
db = get_db()
|
||||
db = self.db if self.db is not None else initialize_database()
|
||||
return list(KeyFact.select().order_by(KeyFact.id))
|
||||
except peewee.DatabaseError as e:
|
||||
logger.error(f"Failed to fetch all key facts: {str(e)}")
|
||||
|
|
|
|||
|
|
@ -0,0 +1,214 @@
|
|||
"""
|
||||
Key snippet repository implementation for database access.
|
||||
|
||||
This module provides a repository implementation for the KeySnippet model,
|
||||
following the repository pattern for data access abstraction.
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Optional, Any
|
||||
|
||||
import peewee
|
||||
|
||||
from ra_aid.database.connection import get_db
|
||||
from ra_aid.database.models import KeySnippet, initialize_database
|
||||
from ra_aid.logging_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class KeySnippetRepository:
|
||||
"""
|
||||
Repository for managing KeySnippet database operations.
|
||||
|
||||
This class provides methods for performing CRUD operations on the KeySnippet model,
|
||||
abstracting the database access details from the business logic.
|
||||
|
||||
Example:
|
||||
repo = KeySnippetRepository()
|
||||
snippet = repo.create(
|
||||
filepath="main.py",
|
||||
line_number=42,
|
||||
snippet="def hello_world():",
|
||||
description="Main function definition"
|
||||
)
|
||||
all_snippets = repo.get_all()
|
||||
"""
|
||||
|
||||
def __init__(self, db=None):
|
||||
"""
|
||||
Initialize the repository with an optional database connection.
|
||||
|
||||
Args:
|
||||
db: Optional database connection to use. If None, will use initialize_database()
|
||||
"""
|
||||
self.db = db
|
||||
|
||||
def create(
|
||||
self, filepath: str, line_number: int, snippet: str, description: Optional[str] = None
|
||||
) -> KeySnippet:
|
||||
"""
|
||||
Create a new key snippet in the database.
|
||||
|
||||
Args:
|
||||
filepath: Path to the source file
|
||||
line_number: Line number where the snippet starts
|
||||
snippet: The source code snippet text
|
||||
description: Optional description of the significance
|
||||
|
||||
Returns:
|
||||
KeySnippet: The newly created key snippet instance
|
||||
|
||||
Raises:
|
||||
peewee.DatabaseError: If there's an error creating the snippet
|
||||
"""
|
||||
try:
|
||||
db = self.db if self.db is not None else initialize_database()
|
||||
key_snippet = KeySnippet.create(
|
||||
filepath=filepath,
|
||||
line_number=line_number,
|
||||
snippet=snippet,
|
||||
description=description
|
||||
)
|
||||
logger.debug(f"Created key snippet ID {key_snippet.id}: {filepath}:{line_number}")
|
||||
return key_snippet
|
||||
except peewee.DatabaseError as e:
|
||||
logger.error(f"Failed to create key snippet: {str(e)}")
|
||||
raise
|
||||
|
||||
def get(self, snippet_id: int) -> Optional[KeySnippet]:
|
||||
"""
|
||||
Retrieve a key snippet by its ID.
|
||||
|
||||
Args:
|
||||
snippet_id: The ID of the key snippet to retrieve
|
||||
|
||||
Returns:
|
||||
Optional[KeySnippet]: The key snippet instance if found, None otherwise
|
||||
|
||||
Raises:
|
||||
peewee.DatabaseError: If there's an error accessing the database
|
||||
"""
|
||||
try:
|
||||
db = self.db if self.db is not None else initialize_database()
|
||||
return KeySnippet.get_or_none(KeySnippet.id == snippet_id)
|
||||
except peewee.DatabaseError as e:
|
||||
logger.error(f"Failed to fetch key snippet {snippet_id}: {str(e)}")
|
||||
raise
|
||||
|
||||
def update(
|
||||
self,
|
||||
snippet_id: int,
|
||||
filepath: str,
|
||||
line_number: int,
|
||||
snippet: str,
|
||||
description: Optional[str] = None
|
||||
) -> Optional[KeySnippet]:
|
||||
"""
|
||||
Update an existing key snippet.
|
||||
|
||||
Args:
|
||||
snippet_id: The ID of the key snippet to update
|
||||
filepath: Path to the source file
|
||||
line_number: Line number where the snippet starts
|
||||
snippet: The source code snippet text
|
||||
description: Optional description of the significance
|
||||
|
||||
Returns:
|
||||
Optional[KeySnippet]: The updated key snippet if found, None otherwise
|
||||
|
||||
Raises:
|
||||
peewee.DatabaseError: If there's an error updating the snippet
|
||||
"""
|
||||
try:
|
||||
db = self.db if self.db is not None else initialize_database()
|
||||
# First check if the snippet exists
|
||||
key_snippet = self.get(snippet_id)
|
||||
if not key_snippet:
|
||||
logger.warning(f"Attempted to update non-existent key snippet {snippet_id}")
|
||||
return None
|
||||
|
||||
# Update the snippet
|
||||
key_snippet.filepath = filepath
|
||||
key_snippet.line_number = line_number
|
||||
key_snippet.snippet = snippet
|
||||
key_snippet.description = description
|
||||
key_snippet.save()
|
||||
logger.debug(f"Updated key snippet ID {snippet_id}: {filepath}:{line_number}")
|
||||
return key_snippet
|
||||
except peewee.DatabaseError as e:
|
||||
logger.error(f"Failed to update key snippet {snippet_id}: {str(e)}")
|
||||
raise
|
||||
|
||||
def delete(self, snippet_id: int) -> bool:
|
||||
"""
|
||||
Delete a key snippet by its ID.
|
||||
|
||||
Args:
|
||||
snippet_id: The ID of the key snippet to delete
|
||||
|
||||
Returns:
|
||||
bool: True if the snippet was deleted, False if it wasn't found
|
||||
|
||||
Raises:
|
||||
peewee.DatabaseError: If there's an error deleting the snippet
|
||||
"""
|
||||
try:
|
||||
db = self.db if self.db is not None else initialize_database()
|
||||
# First check if the snippet exists
|
||||
key_snippet = self.get(snippet_id)
|
||||
if not key_snippet:
|
||||
logger.warning(f"Attempted to delete non-existent key snippet {snippet_id}")
|
||||
return False
|
||||
|
||||
# Delete the snippet
|
||||
key_snippet.delete_instance()
|
||||
logger.debug(f"Deleted key snippet ID {snippet_id}")
|
||||
return True
|
||||
except peewee.DatabaseError as e:
|
||||
logger.error(f"Failed to delete key snippet {snippet_id}: {str(e)}")
|
||||
raise
|
||||
|
||||
def get_all(self) -> List[KeySnippet]:
|
||||
"""
|
||||
Retrieve all key snippets from the database.
|
||||
|
||||
Returns:
|
||||
List[KeySnippet]: List of all key snippet instances
|
||||
|
||||
Raises:
|
||||
peewee.DatabaseError: If there's an error accessing the database
|
||||
"""
|
||||
try:
|
||||
db = self.db if self.db is not None else initialize_database()
|
||||
return list(KeySnippet.select().order_by(KeySnippet.id))
|
||||
except peewee.DatabaseError as e:
|
||||
logger.error(f"Failed to fetch all key snippets: {str(e)}")
|
||||
raise
|
||||
|
||||
def get_snippets_dict(self) -> Dict[int, Dict[str, Any]]:
|
||||
"""
|
||||
Retrieve all key snippets as a dictionary mapping IDs to snippet information.
|
||||
|
||||
This method is useful for compatibility with the existing memory format.
|
||||
|
||||
Returns:
|
||||
Dict[int, Dict[str, Any]]: Dictionary with snippet IDs as keys and
|
||||
snippet information as values
|
||||
|
||||
Raises:
|
||||
peewee.DatabaseError: If there's an error accessing the database
|
||||
"""
|
||||
try:
|
||||
snippets = self.get_all()
|
||||
return {
|
||||
snippet.id: {
|
||||
"filepath": snippet.filepath,
|
||||
"line_number": snippet.line_number,
|
||||
"snippet": snippet.snippet,
|
||||
"description": snippet.description
|
||||
}
|
||||
for snippet in snippets
|
||||
}
|
||||
except peewee.DatabaseError as e:
|
||||
logger.error(f"Failed to fetch key snippets as dictionary: {str(e)}")
|
||||
raise
|
||||
|
|
@ -1,5 +1,8 @@
|
|||
"""
|
||||
Tests for the database connection module.
|
||||
|
||||
This file tests the database connection functionality using pytest's fixtures
|
||||
for proper test isolation.
|
||||
"""
|
||||
|
||||
import os
|
||||
|
|
@ -21,642 +24,346 @@ from ra_aid.database.connection import (
|
|||
@pytest.fixture
|
||||
def cleanup_db():
|
||||
"""
|
||||
Fixture to clean up database connections after tests.
|
||||
Fixture to clean up database connections between tests.
|
||||
|
||||
This ensures that we don't leak database connections between tests
|
||||
and that the db_var contextvar is reset.
|
||||
"""
|
||||
# Run the test
|
||||
yield
|
||||
|
||||
|
||||
# Clean up after the test
|
||||
db = db_var.get()
|
||||
if db is not None:
|
||||
# Clean up attributes we may have added
|
||||
if hasattr(db, "_is_in_memory"):
|
||||
delattr(db, "_is_in_memory")
|
||||
if hasattr(db, "_message_shown"):
|
||||
delattr(db, "_message_shown")
|
||||
|
||||
# Close the connection if it's open
|
||||
if not db.is_closed():
|
||||
db.close()
|
||||
|
||||
|
||||
# Reset the contextvar
|
||||
db_var.set(None)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def setup_in_memory_db():
|
||||
def db_path_mock(tmp_path, monkeypatch):
|
||||
"""
|
||||
Fixture to set up an in-memory database for testing.
|
||||
Fixture to mock os.getcwd() to return a temporary directory path.
|
||||
|
||||
This ensures that all database operations use the temporary directory
|
||||
and never touch the actual current working directory.
|
||||
"""
|
||||
# Initialize in-memory database
|
||||
db = init_db(in_memory=True)
|
||||
|
||||
# Run the test
|
||||
yield db
|
||||
|
||||
# Clean up
|
||||
if not db.is_closed():
|
||||
db.close()
|
||||
db_var.set(None)
|
||||
|
||||
|
||||
def test_init_db_creates_directory(cleanup_db, tmp_path):
|
||||
"""
|
||||
Test that init_db creates the .ra-aid directory if it doesn't exist.
|
||||
"""
|
||||
# Get and print the original working directory
|
||||
original_cwd = os.getcwd()
|
||||
print(f"Original working directory: {original_cwd}")
|
||||
|
||||
# Convert tmp_path to string for consistent handling
|
||||
tmp_path_str = str(tmp_path.absolute())
|
||||
print(f"Temporary directory path: {tmp_path_str}")
|
||||
|
||||
# Change to the temporary directory
|
||||
os.chdir(tmp_path_str)
|
||||
current_cwd = os.getcwd()
|
||||
print(f"Changed working directory to: {current_cwd}")
|
||||
assert (
|
||||
current_cwd == tmp_path_str
|
||||
), f"Failed to change directory: {current_cwd} != {tmp_path_str}"
|
||||
|
||||
# Create the .ra-aid directory manually to ensure it exists
|
||||
ra_aid_path_str = os.path.join(current_cwd, ".ra-aid")
|
||||
print(f"Creating .ra-aid directory at: {ra_aid_path_str}")
|
||||
os.makedirs(ra_aid_path_str, exist_ok=True)
|
||||
|
||||
# Verify the directory was created
|
||||
assert os.path.exists(
|
||||
ra_aid_path_str
|
||||
), f".ra-aid directory not found at {ra_aid_path_str}"
|
||||
assert os.path.isdir(
|
||||
ra_aid_path_str
|
||||
), f"{ra_aid_path_str} exists but is not a directory"
|
||||
|
||||
# Create a test file to verify write permissions
|
||||
test_file_path = os.path.join(ra_aid_path_str, "test_write.txt")
|
||||
print(f"Creating test file to verify write permissions: {test_file_path}")
|
||||
with open(test_file_path, "w") as f:
|
||||
f.write("Test write permissions")
|
||||
|
||||
# Verify the test file was created
|
||||
assert os.path.exists(test_file_path), f"Test file not created at {test_file_path}"
|
||||
|
||||
# Create an empty database file to ensure it exists before init_db
|
||||
db_file_str = os.path.join(ra_aid_path_str, "pk.db")
|
||||
print(f"Creating empty database file at: {db_file_str}")
|
||||
with open(db_file_str, "w") as f:
|
||||
f.write("") # Create empty file
|
||||
|
||||
# Verify the database file was created
|
||||
assert os.path.exists(
|
||||
db_file_str
|
||||
), f"Empty database file not created at {db_file_str}"
|
||||
print(f"Empty database file size: {os.path.getsize(db_file_str)} bytes")
|
||||
|
||||
# Get directory permissions for debugging
|
||||
dir_perms = oct(os.stat(ra_aid_path_str).st_mode)[-3:]
|
||||
print(f"Directory permissions: {dir_perms}")
|
||||
|
||||
# Initialize the database
|
||||
print("Calling init_db()")
|
||||
db = init_db()
|
||||
print("init_db() returned successfully")
|
||||
|
||||
# List contents of the current directory for debugging
|
||||
print(f"Contents of current directory: {os.listdir(current_cwd)}")
|
||||
|
||||
# List contents of the .ra-aid directory for debugging
|
||||
print(f"Contents of .ra-aid directory: {os.listdir(ra_aid_path_str)}")
|
||||
|
||||
# Check that the database file exists using os.path
|
||||
assert os.path.exists(db_file_str), f"Database file not found at {db_file_str}"
|
||||
assert os.path.isfile(db_file_str), f"{db_file_str} exists but is not a file"
|
||||
print(f"Final database file size: {os.path.getsize(db_file_str)} bytes")
|
||||
|
||||
|
||||
def test_init_db_creates_database_file(cleanup_db, tmp_path):
|
||||
"""
|
||||
Test that init_db creates the database file.
|
||||
"""
|
||||
# Change to the temporary directory
|
||||
os.chdir(tmp_path)
|
||||
|
||||
# Initialize the database
|
||||
init_db()
|
||||
|
||||
# Check that the database file was created
|
||||
assert (tmp_path / ".ra-aid" / "pk.db").exists()
|
||||
assert (tmp_path / ".ra-aid" / "pk.db").is_file()
|
||||
|
||||
|
||||
def test_init_db_returns_database_connection(cleanup_db):
|
||||
"""
|
||||
Test that init_db returns a database connection.
|
||||
"""
|
||||
# Initialize the database
|
||||
db = init_db()
|
||||
|
||||
# Check that the database connection is returned
|
||||
assert isinstance(db, peewee.SqliteDatabase)
|
||||
assert not db.is_closed()
|
||||
|
||||
|
||||
def test_init_db_with_in_memory_mode(cleanup_db):
|
||||
"""
|
||||
Test that init_db with in_memory=True creates an in-memory database.
|
||||
"""
|
||||
# Initialize the database in in-memory mode
|
||||
db = init_db(in_memory=True)
|
||||
|
||||
# Check that the database connection is returned
|
||||
assert isinstance(db, peewee.SqliteDatabase)
|
||||
assert not db.is_closed()
|
||||
assert hasattr(db, "_is_in_memory")
|
||||
assert db._is_in_memory is True
|
||||
|
||||
|
||||
def test_in_memory_mode_no_directory_created(cleanup_db, tmp_path):
|
||||
"""
|
||||
Test that when using in-memory mode, no directory is created.
|
||||
"""
|
||||
# Change to the temporary directory
|
||||
os.chdir(tmp_path)
|
||||
|
||||
# Initialize the database in in-memory mode
|
||||
init_db(in_memory=True)
|
||||
|
||||
# Check that the .ra-aid directory was not created
|
||||
# (Note: it might be created by other tests, so we can't assert it doesn't exist)
|
||||
# Instead, check that the database file was not created
|
||||
assert not (tmp_path / ".ra-aid" / "pk.db").exists()
|
||||
|
||||
|
||||
def test_init_db_returns_existing_connection(cleanup_db):
|
||||
"""
|
||||
Test that init_db returns the existing connection if one exists.
|
||||
"""
|
||||
# Initialize the database
|
||||
db1 = init_db()
|
||||
|
||||
# Initialize the database again
|
||||
db2 = init_db()
|
||||
|
||||
# Check that the same connection is returned
|
||||
assert db1 is db2
|
||||
|
||||
|
||||
def test_init_db_reopens_closed_connection(cleanup_db):
|
||||
"""
|
||||
Test that init_db reopens a closed connection.
|
||||
"""
|
||||
# Initialize the database
|
||||
db1 = init_db()
|
||||
|
||||
# Close the connection
|
||||
db1.close()
|
||||
|
||||
# Initialize the database again
|
||||
db2 = init_db()
|
||||
|
||||
# Check that the same connection is returned and it's open
|
||||
assert db1 is db2
|
||||
assert not db1.is_closed()
|
||||
|
||||
|
||||
def test_get_db_initializes_connection(cleanup_db):
|
||||
"""
|
||||
Test that get_db initializes a connection if none exists.
|
||||
"""
|
||||
# Get the database connection
|
||||
db = get_db()
|
||||
|
||||
# Check that a connection was initialized
|
||||
assert isinstance(db, peewee.SqliteDatabase)
|
||||
assert not db.is_closed()
|
||||
|
||||
|
||||
def test_get_db_returns_existing_connection(cleanup_db):
|
||||
"""
|
||||
Test that get_db returns the existing connection if one exists.
|
||||
"""
|
||||
# Initialize the database
|
||||
db1 = init_db()
|
||||
|
||||
# Get the database connection
|
||||
db2 = get_db()
|
||||
|
||||
# Check that the same connection is returned
|
||||
assert db1 is db2
|
||||
|
||||
|
||||
def test_get_db_reopens_closed_connection(cleanup_db):
|
||||
"""
|
||||
Test that get_db reopens a closed connection.
|
||||
"""
|
||||
# Initialize the database
|
||||
db = init_db()
|
||||
|
||||
# Close the connection
|
||||
db.close()
|
||||
|
||||
# Get the database connection
|
||||
db2 = get_db()
|
||||
|
||||
# Check that the same connection is returned and it's open
|
||||
assert db is db2
|
||||
assert not db.is_closed()
|
||||
|
||||
|
||||
def test_get_db_handles_reopen_error(cleanup_db, monkeypatch):
|
||||
"""
|
||||
Test that get_db handles errors when reopening a connection.
|
||||
"""
|
||||
# Initialize the database
|
||||
db = init_db()
|
||||
|
||||
# Close the connection
|
||||
db.close()
|
||||
|
||||
# Create a patched version of the connect method that raises an error
|
||||
original_connect = peewee.SqliteDatabase.connect
|
||||
|
||||
def mock_connect(self, *args, **kwargs):
|
||||
if self is db: # Only raise for the specific db instance
|
||||
raise peewee.OperationalError("Test error")
|
||||
return original_connect(self, *args, **kwargs)
|
||||
|
||||
# Apply the patch
|
||||
monkeypatch.setattr(peewee.SqliteDatabase, "connect", mock_connect)
|
||||
|
||||
# Get the database connection
|
||||
db2 = get_db()
|
||||
|
||||
# Check that a new connection was initialized
|
||||
assert db is not db2
|
||||
assert not db2.is_closed()
|
||||
|
||||
|
||||
def test_close_db_closes_connection(cleanup_db):
|
||||
"""
|
||||
Test that close_db closes the connection.
|
||||
"""
|
||||
# Initialize the database
|
||||
db = init_db()
|
||||
|
||||
# Close the connection
|
||||
close_db()
|
||||
|
||||
# Check that the connection is closed
|
||||
assert db.is_closed()
|
||||
|
||||
|
||||
def test_close_db_handles_no_connection():
|
||||
"""
|
||||
Test that close_db handles the case where no connection exists.
|
||||
"""
|
||||
# Reset the contextvar
|
||||
db_var.set(None)
|
||||
|
||||
# Close the connection (should not raise an error)
|
||||
close_db()
|
||||
|
||||
|
||||
def test_close_db_handles_already_closed_connection(cleanup_db):
|
||||
"""
|
||||
Test that close_db handles the case where the connection is already closed.
|
||||
"""
|
||||
# Initialize the database
|
||||
db = init_db()
|
||||
|
||||
# Close the connection
|
||||
db.close()
|
||||
|
||||
# Close the connection again (should not raise an error)
|
||||
close_db()
|
||||
|
||||
|
||||
@patch("ra_aid.database.connection.peewee.SqliteDatabase.close")
|
||||
def test_close_db_handles_error(mock_close, cleanup_db):
|
||||
"""
|
||||
Test that close_db handles errors when closing the connection.
|
||||
"""
|
||||
# Initialize the database
|
||||
init_db()
|
||||
|
||||
# Make close raise an error
|
||||
mock_close.side_effect = peewee.DatabaseError("Test error")
|
||||
|
||||
# Close the connection (should not raise an error)
|
||||
close_db()
|
||||
|
||||
|
||||
def test_database_manager_context_manager(cleanup_db):
|
||||
"""
|
||||
Test that DatabaseManager works as a context manager.
|
||||
"""
|
||||
# Use the context manager
|
||||
with DatabaseManager() as db:
|
||||
# Check that a connection was initialized
|
||||
assert isinstance(db, peewee.SqliteDatabase)
|
||||
assert not db.is_closed()
|
||||
|
||||
# Store the connection for later
|
||||
db_in_context = db
|
||||
|
||||
# Check that the connection is closed after exiting the context
|
||||
assert db_in_context.is_closed()
|
||||
|
||||
|
||||
def test_database_manager_with_in_memory_mode(cleanup_db):
|
||||
"""
|
||||
Test that DatabaseManager with in_memory=True creates an in-memory database.
|
||||
"""
|
||||
# Use the context manager with in_memory=True
|
||||
with DatabaseManager(in_memory=True) as db:
|
||||
# Check that a connection was initialized
|
||||
assert isinstance(db, peewee.SqliteDatabase)
|
||||
assert not db.is_closed()
|
||||
assert hasattr(db, "_is_in_memory")
|
||||
assert db._is_in_memory is True
|
||||
|
||||
|
||||
def test_init_db_shows_message_only_once(cleanup_db, caplog):
|
||||
"""
|
||||
Test that init_db only shows the initialization message once.
|
||||
"""
|
||||
# Initialize the database
|
||||
init_db(in_memory=True)
|
||||
|
||||
# Clear the log
|
||||
caplog.clear()
|
||||
|
||||
# Initialize the database again
|
||||
init_db(in_memory=True)
|
||||
|
||||
# Check that no message was logged
|
||||
assert "database connection initialized" not in caplog.text.lower()
|
||||
|
||||
|
||||
def test_init_db_sets_is_in_memory_attribute(cleanup_db):
|
||||
"""
|
||||
Test that init_db sets the _is_in_memory attribute.
|
||||
"""
|
||||
# Initialize the database with in_memory=False
|
||||
db = init_db(in_memory=False)
|
||||
|
||||
# Check that the _is_in_memory attribute is set to False
|
||||
assert hasattr(db, "_is_in_memory")
|
||||
assert db._is_in_memory is False
|
||||
|
||||
# Reset the contextvar
|
||||
db_var.set(None)
|
||||
|
||||
# Initialize the database with in_memory=True
|
||||
db = init_db(in_memory=True)
|
||||
|
||||
# Check that the _is_in_memory attribute is set to True
|
||||
assert hasattr(db, "_is_in_memory")
|
||||
assert db._is_in_memory is True
|
||||
|
||||
|
||||
"""
|
||||
Tests for the database connection module.
|
||||
"""
|
||||
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def cleanup_db():
|
||||
"""
|
||||
Fixture to clean up database connections and files between tests.
|
||||
|
||||
This fixture:
|
||||
1. Closes any open database connection
|
||||
2. Resets the contextvar
|
||||
3. Cleans up the .ra-aid directory
|
||||
"""
|
||||
# Store the current working directory
|
||||
original_cwd = os.getcwd()
|
||||
|
||||
# Run the test
|
||||
yield
|
||||
|
||||
# Clean up after the test
|
||||
try:
|
||||
# Close any open database connection
|
||||
close_db()
|
||||
|
||||
# Reset the contextvar
|
||||
db_var.set(None)
|
||||
|
||||
# Clean up the .ra-aid directory if it exists
|
||||
ra_aid_dir = Path(os.getcwd()) / ".ra-aid"
|
||||
ra_aid_dir_str = str(ra_aid_dir.absolute())
|
||||
|
||||
# Check using both methods
|
||||
path_exists = ra_aid_dir.exists()
|
||||
os_exists = os.path.exists(ra_aid_dir_str)
|
||||
|
||||
print(f"Cleanup check: Path.exists={path_exists}, os.path.exists={os_exists}")
|
||||
|
||||
if os_exists:
|
||||
# Only remove the database file, not the entire directory
|
||||
db_file = os.path.join(ra_aid_dir_str, "pk.db")
|
||||
if os.path.exists(db_file):
|
||||
os.unlink(db_file)
|
||||
|
||||
# Remove WAL and SHM files if they exist
|
||||
wal_file = os.path.join(ra_aid_dir_str, "pk.db-wal")
|
||||
if os.path.exists(wal_file):
|
||||
os.unlink(wal_file)
|
||||
|
||||
shm_file = os.path.join(ra_aid_dir_str, "pk.db-shm")
|
||||
if os.path.exists(shm_file):
|
||||
os.unlink(shm_file)
|
||||
|
||||
# List remaining contents for debugging
|
||||
if os.path.exists(ra_aid_dir_str):
|
||||
print(f"Directory contents after cleanup: {os.listdir(ra_aid_dir_str)}")
|
||||
except Exception as e:
|
||||
# Log but don't fail if cleanup has issues
|
||||
print(f"Cleanup error (non-fatal): {str(e)}")
|
||||
|
||||
# Make sure we're back in the original directory
|
||||
|
||||
# Create the .ra-aid directory in the temporary path
|
||||
ra_aid_dir = tmp_path / ".ra-aid"
|
||||
ra_aid_dir.mkdir(exist_ok=True)
|
||||
|
||||
# Mock os.getcwd() to return the temporary directory
|
||||
monkeypatch.setattr(os, "getcwd", lambda: tmp_path_str)
|
||||
|
||||
yield tmp_path
|
||||
|
||||
# Ensure we're back to the original directory after the test
|
||||
os.chdir(original_cwd)
|
||||
|
||||
|
||||
class TestInitDb:
|
||||
"""Tests for the init_db function."""
|
||||
|
||||
def test_init_db_default(self, cleanup_db):
|
||||
"""Test init_db with default parameters."""
|
||||
# Get the absolute path of the current working directory
|
||||
cwd = os.getcwd()
|
||||
print(f"Current working directory: {cwd}")
|
||||
|
||||
# Initialize the database
|
||||
db = init_db()
|
||||
|
||||
assert isinstance(db, peewee.SqliteDatabase)
|
||||
assert not db.is_closed()
|
||||
assert hasattr(db, "_is_in_memory")
|
||||
assert db._is_in_memory is False
|
||||
|
||||
# Verify the database file was created using both Path and os.path methods
|
||||
ra_aid_dir = Path(cwd) / ".ra-aid"
|
||||
ra_aid_dir_str = str(ra_aid_dir.absolute())
|
||||
|
||||
# Check directory existence using both methods
|
||||
path_exists = ra_aid_dir.exists()
|
||||
os_exists = os.path.exists(ra_aid_dir_str)
|
||||
print(f"Directory check: Path.exists={path_exists}, os.path.exists={os_exists}")
|
||||
|
||||
# List the contents of the current directory
|
||||
print(f"Contents of {cwd}: {os.listdir(cwd)}")
|
||||
|
||||
# If the directory exists, list its contents
|
||||
if os_exists:
|
||||
print(f"Contents of {ra_aid_dir_str}: {os.listdir(ra_aid_dir_str)}")
|
||||
|
||||
# Use os.path for assertions to be more reliable
|
||||
assert os.path.exists(
|
||||
ra_aid_dir_str
|
||||
), f"Directory {ra_aid_dir_str} does not exist"
|
||||
assert os.path.isdir(ra_aid_dir_str), f"{ra_aid_dir_str} is not a directory"
|
||||
|
||||
db_file = os.path.join(ra_aid_dir_str, "pk.db")
|
||||
assert os.path.exists(db_file), f"Database file {db_file} does not exist"
|
||||
assert os.path.isfile(db_file), f"{db_file} is not a file"
|
||||
|
||||
|
||||
def test_init_db_in_memory(self, cleanup_db):
|
||||
"""Test init_db with in_memory=True."""
|
||||
# Reset the contextvar to ensure a fresh start
|
||||
db_var.set(None)
|
||||
db = init_db(in_memory=True)
|
||||
|
||||
|
||||
assert isinstance(db, peewee.SqliteDatabase)
|
||||
assert not db.is_closed()
|
||||
assert hasattr(db, "_is_in_memory")
|
||||
assert db._is_in_memory is True
|
||||
|
||||
def test_init_db_reuses_connection(self, cleanup_db):
|
||||
"""Test that init_db reuses an existing connection."""
|
||||
db1 = init_db()
|
||||
db2 = init_db()
|
||||
|
||||
assert db1 is db2
|
||||
|
||||
def test_init_db_reopens_closed_connection(self, cleanup_db):
|
||||
"""Test that init_db reopens a closed connection."""
|
||||
db1 = init_db()
|
||||
db1.close()
|
||||
assert db1.is_closed()
|
||||
|
||||
db2 = init_db()
|
||||
assert db1 is db2
|
||||
assert not db1.is_closed()
|
||||
|
||||
|
||||
class TestGetDb:
|
||||
"""Tests for the get_db function."""
|
||||
|
||||
def test_get_db_creates_connection(self, cleanup_db):
|
||||
"""Test that get_db creates a new connection if none exists."""
|
||||
# Reset the contextvar to ensure no connection exists
|
||||
db_var.set(None)
|
||||
|
||||
db = get_db()
|
||||
|
||||
|
||||
def test_init_db_creates_directory(self, cleanup_db, db_path_mock):
|
||||
"""Test that init_db creates the .ra-aid directory if it doesn't exist."""
|
||||
# Remove the .ra-aid directory to test creation
|
||||
ra_aid_dir = db_path_mock / ".ra-aid"
|
||||
if ra_aid_dir.exists():
|
||||
for item in ra_aid_dir.iterdir():
|
||||
if item.is_file():
|
||||
item.unlink()
|
||||
ra_aid_dir.rmdir()
|
||||
|
||||
# Initialize the database
|
||||
db = init_db()
|
||||
|
||||
# Check that the directory was created
|
||||
assert ra_aid_dir.exists()
|
||||
assert ra_aid_dir.is_dir()
|
||||
assert isinstance(db, peewee.SqliteDatabase)
|
||||
assert not db.is_closed()
|
||||
assert hasattr(db, "_is_in_memory")
|
||||
assert db._is_in_memory is False
|
||||
|
||||
def test_get_db_reuses_connection(self, cleanup_db):
|
||||
"""Test that get_db reuses an existing connection."""
|
||||
db1 = init_db()
|
||||
db2 = get_db()
|
||||
|
||||
|
||||
def test_init_db_creates_database_file(self, cleanup_db, db_path_mock):
|
||||
"""Test that init_db creates the database file."""
|
||||
# Initialize the database
|
||||
init_db()
|
||||
|
||||
# Check that the database file was created
|
||||
assert (db_path_mock / ".ra-aid" / "pk.db").exists()
|
||||
assert (db_path_mock / ".ra-aid" / "pk.db").is_file()
|
||||
|
||||
def test_init_db_reuses_connection(self, cleanup_db):
|
||||
"""Test that init_db reuses an existing connection."""
|
||||
# Reset the contextvar to ensure a fresh start
|
||||
db_var.set(None)
|
||||
|
||||
# Use in_memory=True for this test to avoid touching the filesystem
|
||||
db1 = init_db(in_memory=True)
|
||||
db2 = init_db(in_memory=True)
|
||||
|
||||
assert db1 is db2
|
||||
|
||||
def test_get_db_reopens_closed_connection(self, cleanup_db):
|
||||
"""Test that get_db reopens a closed connection."""
|
||||
db1 = init_db()
|
||||
|
||||
def test_init_db_reopens_closed_connection(self, cleanup_db):
|
||||
"""Test that init_db reopens a closed connection."""
|
||||
# Reset the contextvar to ensure a fresh start
|
||||
db_var.set(None)
|
||||
|
||||
# Use in_memory=True for this test to avoid touching the filesystem
|
||||
db1 = init_db(in_memory=True)
|
||||
db1.close()
|
||||
assert db1.is_closed()
|
||||
|
||||
db2 = init_db(in_memory=True)
|
||||
assert db1 is db2
|
||||
assert not db1.is_closed()
|
||||
|
||||
def test_in_memory_mode_no_directory_created(self, cleanup_db, db_path_mock):
|
||||
"""Test that when using in_memory mode, no database file is created."""
|
||||
# Initialize the database in in-memory mode
|
||||
init_db(in_memory=True)
|
||||
|
||||
# Check that the database file was not created
|
||||
assert not (db_path_mock / ".ra-aid" / "pk.db").exists()
|
||||
|
||||
def test_init_db_sets_is_in_memory_attribute(self, cleanup_db):
|
||||
"""Test that init_db sets the _is_in_memory attribute."""
|
||||
# Test with in_memory=True
|
||||
db = init_db(in_memory=True)
|
||||
assert hasattr(db, "_is_in_memory")
|
||||
assert db._is_in_memory is True
|
||||
|
||||
# Reset the contextvar
|
||||
db_var.set(None)
|
||||
|
||||
# Test with in_memory=False, but use a mocked directory
|
||||
with patch("os.getcwd") as mock_getcwd:
|
||||
temp_dir = Path("/tmp/testdb")
|
||||
mock_getcwd.return_value = str(temp_dir)
|
||||
|
||||
# Mock os.path.exists and os.makedirs to avoid filesystem operations
|
||||
with patch("os.path.exists", return_value=True):
|
||||
with patch("os.makedirs"):
|
||||
with patch("os.path.isdir", return_value=True):
|
||||
with patch.object(peewee.SqliteDatabase, "connect"):
|
||||
with patch.object(peewee.SqliteDatabase, "execute_sql"):
|
||||
db = init_db(in_memory=False)
|
||||
assert hasattr(db, "_is_in_memory")
|
||||
assert db._is_in_memory is False
|
||||
|
||||
|
||||
class TestGetDb:
|
||||
"""Tests for the get_db function."""
|
||||
|
||||
def test_get_db_initializes_connection(self, cleanup_db):
|
||||
"""Test that get_db initializes a connection if none exists."""
|
||||
# Reset the contextvar to ensure no connection exists
|
||||
db_var.set(None)
|
||||
|
||||
# Use a patch to avoid touching the filesystem
|
||||
with patch("ra_aid.database.connection.init_db") as mock_init_db:
|
||||
mock_db = peewee.SqliteDatabase(":memory:")
|
||||
mock_db._is_in_memory = False
|
||||
mock_init_db.return_value = mock_db
|
||||
|
||||
db = get_db()
|
||||
|
||||
mock_init_db.assert_called_once_with(in_memory=False, base_dir=None)
|
||||
assert db is mock_db
|
||||
|
||||
def test_get_db_returns_existing_connection(self, cleanup_db):
|
||||
"""Test that get_db returns the existing connection if one exists."""
|
||||
# Reset the contextvar to ensure a fresh start
|
||||
db_var.set(None)
|
||||
|
||||
# Use in_memory=True for this test to avoid touching the filesystem
|
||||
db1 = init_db(in_memory=True)
|
||||
db2 = get_db()
|
||||
|
||||
assert db1 is db2
|
||||
|
||||
def test_get_db_reopens_closed_connection(self, cleanup_db):
|
||||
"""Test that get_db reopens a closed connection."""
|
||||
# Reset the contextvar to ensure a fresh start
|
||||
db_var.set(None)
|
||||
|
||||
# Use in_memory=True for this test to avoid touching the filesystem
|
||||
db1 = init_db(in_memory=True)
|
||||
db1.close()
|
||||
assert db1.is_closed()
|
||||
|
||||
db2 = get_db()
|
||||
assert db1 is db2
|
||||
assert not db1.is_closed()
|
||||
|
||||
def test_get_db_handles_reopen_error(self, cleanup_db, monkeypatch):
|
||||
"""Test that get_db handles errors when reopening a connection."""
|
||||
# Reset the contextvar to ensure a fresh start
|
||||
db_var.set(None)
|
||||
|
||||
# Use in_memory=True for this test to avoid touching the filesystem
|
||||
db = init_db(in_memory=True)
|
||||
|
||||
# Close the connection
|
||||
db.close()
|
||||
|
||||
# Create a patched version of the connect method that raises an error
|
||||
original_connect = peewee.SqliteDatabase.connect
|
||||
|
||||
def mock_connect(self, *args, **kwargs):
|
||||
if self is db: # Only raise for the specific db instance
|
||||
raise peewee.OperationalError("Test error")
|
||||
return original_connect(self, *args, **kwargs)
|
||||
|
||||
# Apply the patch
|
||||
monkeypatch.setattr(peewee.SqliteDatabase, "connect", mock_connect)
|
||||
|
||||
# Get the database connection - this should create a new one
|
||||
db2 = get_db()
|
||||
|
||||
# Check that a new connection was initialized
|
||||
assert db is not db2
|
||||
assert not db2.is_closed()
|
||||
assert hasattr(db2, "_is_in_memory")
|
||||
assert db2._is_in_memory is True # Should preserve the in_memory setting
|
||||
|
||||
|
||||
class TestCloseDb:
|
||||
"""Tests for the close_db function."""
|
||||
|
||||
def test_close_db(self, cleanup_db):
|
||||
"""Test that close_db closes an open connection."""
|
||||
db = init_db()
|
||||
assert not db.is_closed()
|
||||
|
||||
|
||||
def test_close_db_closes_connection(self, cleanup_db):
|
||||
"""Test that close_db closes the connection."""
|
||||
# Use in_memory=True for this test to avoid touching the filesystem
|
||||
db = init_db(in_memory=True)
|
||||
|
||||
# Close the connection
|
||||
close_db()
|
||||
|
||||
# Check that the connection is closed
|
||||
assert db.is_closed()
|
||||
|
||||
def test_close_db_no_connection(self, cleanup_db):
|
||||
|
||||
def test_close_db_handles_no_connection(self):
|
||||
"""Test that close_db handles the case where no connection exists."""
|
||||
# Reset the contextvar to ensure no connection exists
|
||||
# Reset the contextvar
|
||||
db_var.set(None)
|
||||
|
||||
# This should not raise an exception
|
||||
|
||||
# Close the connection (should not raise an error)
|
||||
close_db()
|
||||
|
||||
def test_close_db_already_closed(self, cleanup_db):
|
||||
|
||||
def test_close_db_handles_already_closed_connection(self, cleanup_db):
|
||||
"""Test that close_db handles the case where the connection is already closed."""
|
||||
db = init_db()
|
||||
# Use in_memory=True for this test to avoid touching the filesystem
|
||||
db = init_db(in_memory=True)
|
||||
|
||||
# Close the connection
|
||||
db.close()
|
||||
assert db.is_closed()
|
||||
|
||||
# This should not raise an exception
|
||||
|
||||
# Close the connection again (should not raise an error)
|
||||
close_db()
|
||||
|
||||
@patch("ra_aid.database.connection.peewee.SqliteDatabase.close")
|
||||
def test_close_db_handles_error(self, mock_close, cleanup_db):
|
||||
"""Test that close_db handles errors when closing the connection."""
|
||||
# Use in_memory=True for this test to avoid touching the filesystem
|
||||
init_db(in_memory=True)
|
||||
|
||||
# Make close raise an error
|
||||
mock_close.side_effect = peewee.DatabaseError("Test error")
|
||||
|
||||
# Close the connection (should not raise an error)
|
||||
close_db()
|
||||
|
||||
|
||||
class TestDatabaseManager:
|
||||
"""Tests for the DatabaseManager class."""
|
||||
|
||||
def test_database_manager_default(self, cleanup_db):
|
||||
"""Test DatabaseManager with default parameters."""
|
||||
with DatabaseManager() as db:
|
||||
assert isinstance(db, peewee.SqliteDatabase)
|
||||
assert not db.is_closed()
|
||||
assert hasattr(db, "_is_in_memory")
|
||||
assert db._is_in_memory is False
|
||||
|
||||
# Verify the database file was created
|
||||
ra_aid_dir = Path(os.getcwd()) / ".ra-aid"
|
||||
assert ra_aid_dir.exists()
|
||||
assert (ra_aid_dir / "pk.db").exists()
|
||||
|
||||
# Verify the connection is closed after exiting the context
|
||||
assert db.is_closed()
|
||||
|
||||
def test_database_manager_in_memory(self, cleanup_db):
|
||||
"""Test DatabaseManager with in_memory=True."""
|
||||
|
||||
def test_database_manager_context_manager_in_memory(self, cleanup_db):
|
||||
"""Test that DatabaseManager works as a context manager with in_memory=True."""
|
||||
# Use in_memory=True for this test to avoid touching the filesystem
|
||||
with DatabaseManager(in_memory=True) as db:
|
||||
# Check that a connection was initialized
|
||||
assert isinstance(db, peewee.SqliteDatabase)
|
||||
assert not db.is_closed()
|
||||
assert hasattr(db, "_is_in_memory")
|
||||
assert db._is_in_memory is True
|
||||
|
||||
# Verify the connection is closed after exiting the context
|
||||
assert db.is_closed()
|
||||
|
||||
|
||||
# Store the connection for later
|
||||
db_in_context = db
|
||||
|
||||
# Check that the connection is closed after exiting the context
|
||||
assert db_in_context.is_closed()
|
||||
|
||||
def test_database_manager_context_manager_physical_file(self, cleanup_db, db_path_mock):
|
||||
"""Test that DatabaseManager works as a context manager with a physical file."""
|
||||
with DatabaseManager(in_memory=False) as db:
|
||||
# Check that a connection was initialized
|
||||
assert isinstance(db, peewee.SqliteDatabase)
|
||||
assert not db.is_closed()
|
||||
assert hasattr(db, "_is_in_memory")
|
||||
assert db._is_in_memory is False
|
||||
|
||||
# Check that the database file was created
|
||||
assert (db_path_mock / ".ra-aid" / "pk.db").exists()
|
||||
assert (db_path_mock / ".ra-aid" / "pk.db").is_file()
|
||||
|
||||
# Store the connection for later
|
||||
db_in_context = db
|
||||
|
||||
# Check that the connection is closed after exiting the context
|
||||
assert db_in_context.is_closed()
|
||||
|
||||
def test_database_manager_exception_handling(self, cleanup_db):
|
||||
"""Test that DatabaseManager properly handles exceptions."""
|
||||
# Use in_memory=True for this test to avoid touching the filesystem
|
||||
try:
|
||||
with DatabaseManager() as db:
|
||||
with DatabaseManager(in_memory=True) as db:
|
||||
assert not db.is_closed()
|
||||
raise ValueError("Test exception")
|
||||
except ValueError:
|
||||
# The exception should be propagated
|
||||
pass
|
||||
|
||||
|
||||
# Verify the connection is closed even if an exception occurred
|
||||
assert db.is_closed()
|
||||
|
||||
|
||||
def test_init_db_shows_message_only_once(cleanup_db, caplog):
|
||||
"""Test that init_db only shows the initialization message once."""
|
||||
# Reset the contextvar to ensure a fresh start
|
||||
db_var.set(None)
|
||||
|
||||
# Use in_memory=True for this test to avoid touching the filesystem
|
||||
init_db(in_memory=True)
|
||||
|
||||
# Clear the log
|
||||
caplog.clear()
|
||||
|
||||
# Initialize the database again
|
||||
init_db(in_memory=True)
|
||||
|
||||
# Check that no message was logged
|
||||
assert "database connection initialized" not in caplog.text.lower()
|
||||
|
|
@ -10,7 +10,7 @@ from typing import List, Type
|
|||
import peewee
|
||||
|
||||
from ra_aid.database.connection import get_db
|
||||
from ra_aid.database.models import BaseModel
|
||||
from ra_aid.database.models import BaseModel, initialize_database
|
||||
from ra_aid.logging_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
|
@ -26,7 +26,7 @@ def ensure_tables_created(models: List[Type[BaseModel]] = None) -> None:
|
|||
Args:
|
||||
models: Optional list of model classes to create tables for
|
||||
"""
|
||||
db = get_db()
|
||||
db = initialize_database()
|
||||
|
||||
if models is None:
|
||||
# If no models are specified, try to discover them
|
||||
|
|
@ -88,7 +88,7 @@ def truncate_table(model_class: Type[BaseModel]) -> None:
|
|||
Args:
|
||||
model_class: The model class to truncate
|
||||
"""
|
||||
db = get_db()
|
||||
db = initialize_database()
|
||||
try:
|
||||
with db.atomic():
|
||||
model_class.delete().execute()
|
||||
|
|
|
|||
|
|
@ -0,0 +1,57 @@
|
|||
"""Peewee migrations -- 003_20250302_163752_add_key_snippet_model.py.
|
||||
|
||||
Some examples (model - class or model name)::
|
||||
|
||||
> Model = migrator.orm['table_name'] # Return model in current state by name
|
||||
> Model = migrator.ModelClass # Return model in current state by name
|
||||
|
||||
> migrator.sql(sql) # Run custom SQL
|
||||
> migrator.run(func, *args, **kwargs) # Run python function with the given args
|
||||
> migrator.create_model(Model) # Create a model (could be used as decorator)
|
||||
> migrator.remove_model(model, cascade=True) # Remove a model
|
||||
> migrator.add_fields(model, **fields) # Add fields to a model
|
||||
> migrator.change_fields(model, **fields) # Change fields
|
||||
> migrator.remove_fields(model, *field_names, cascade=True)
|
||||
> migrator.rename_field(model, old_field_name, new_field_name)
|
||||
> migrator.rename_table(model, new_table_name)
|
||||
> migrator.add_index(model, *col_names, unique=False)
|
||||
> migrator.add_not_null(model, *field_names)
|
||||
> migrator.add_default(model, field_name, default)
|
||||
> migrator.add_constraint(model, name, sql)
|
||||
> migrator.drop_index(model, *col_names)
|
||||
> migrator.drop_not_null(model, *field_names)
|
||||
> migrator.drop_constraints(model, *constraints)
|
||||
|
||||
"""
|
||||
|
||||
from contextlib import suppress
|
||||
|
||||
import peewee as pw
|
||||
from peewee_migrate import Migrator
|
||||
|
||||
|
||||
with suppress(ImportError):
|
||||
import playhouse.postgres_ext as pw_pext
|
||||
|
||||
|
||||
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
"""Write your migrations here."""
|
||||
|
||||
@migrator.create_model
|
||||
class KeySnippet(pw.Model):
|
||||
id = pw.AutoField()
|
||||
created_at = pw.DateTimeField()
|
||||
updated_at = pw.DateTimeField()
|
||||
filepath = pw.TextField()
|
||||
line_number = pw.IntegerField()
|
||||
snippet = pw.TextField()
|
||||
description = pw.TextField(null=True)
|
||||
|
||||
class Meta:
|
||||
table_name = "key_snippet"
|
||||
|
||||
|
||||
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
"""Write your rollback migrations here."""
|
||||
|
||||
migrator.remove_model('key_snippet')
|
||||
|
|
@ -0,0 +1,86 @@
|
|||
"""
|
||||
Key snippets model formatter module.
|
||||
|
||||
This module provides utility functions for formatting key snippets from database models
|
||||
into consistent markdown styling for display or output purposes.
|
||||
"""
|
||||
|
||||
from typing import Dict, Optional
|
||||
|
||||
|
||||
def format_key_snippet(snippet_id: int, filepath: str, line_number: int, snippet: str, description: Optional[str] = None) -> str:
|
||||
"""
|
||||
Format a single key snippet with markdown formatting.
|
||||
|
||||
Args:
|
||||
snippet_id: The identifier of the snippet
|
||||
filepath: Path to the source file
|
||||
line_number: Line number where the snippet starts
|
||||
snippet: The source code snippet text
|
||||
description: Optional description of the significance
|
||||
|
||||
Returns:
|
||||
str: Formatted key snippet as markdown
|
||||
|
||||
Example:
|
||||
>>> format_key_snippet(1, "src/main.py", 10, "def hello():\\n return 'world'", "Main function")
|
||||
'## 📝 Code Snippet #1\\n\\n**Source Location**:\\n- File: `src/main.py`\\n- Line: `10`\\n\\n**Code**:\\n```python\\ndef hello():\\n return \\'world\\'\\n```\\n\\n**Description**:\\nMain function'
|
||||
"""
|
||||
if not snippet:
|
||||
return ""
|
||||
|
||||
formatted_snippet = f"## 📝 Code Snippet #{snippet_id}\n\n"
|
||||
formatted_snippet += f"**Source Location**:\n"
|
||||
formatted_snippet += f"- File: `{filepath}`\n"
|
||||
formatted_snippet += f"- Line: `{line_number}`\n\n"
|
||||
formatted_snippet += f"**Code**:\n```python\n{snippet}\n```\n"
|
||||
|
||||
if description:
|
||||
formatted_snippet += f"\n**Description**:\n{description}"
|
||||
|
||||
return formatted_snippet
|
||||
|
||||
|
||||
def format_key_snippets_dict(snippets_dict: Dict[int, Dict]) -> str:
|
||||
"""
|
||||
Format a dictionary of key snippets with consistent markdown formatting.
|
||||
|
||||
Args:
|
||||
snippets_dict: Dictionary mapping snippet IDs to snippet information dictionaries.
|
||||
Each snippet dictionary should contain: filepath, line_number, snippet,
|
||||
and optionally description.
|
||||
|
||||
Returns:
|
||||
str: Formatted key snippets as markdown with proper spacing and headings
|
||||
|
||||
Example:
|
||||
>>> snippets = {
|
||||
... 1: {
|
||||
... "filepath": "src/main.py",
|
||||
... "line_number": 10,
|
||||
... "snippet": "def hello():\\n return 'world'",
|
||||
... "description": "Main function"
|
||||
... }
|
||||
... }
|
||||
>>> format_key_snippets_dict(snippets)
|
||||
'## 📝 Code Snippet #1\\n\\n**Source Location**:\\n- File: `src/main.py`\\n- Line: `10`\\n\\n**Code**:\\n```python\\ndef hello():\\n return \\'world\\'\\n```\\n\\n**Description**:\\nMain function'
|
||||
"""
|
||||
if not snippets_dict:
|
||||
return ""
|
||||
|
||||
# Sort by ID for consistent output and format as markdown sections
|
||||
snippets = []
|
||||
for snippet_id, snippet_info in sorted(snippets_dict.items()):
|
||||
snippets.extend([
|
||||
format_key_snippet(
|
||||
snippet_id,
|
||||
snippet_info.get("filepath", ""),
|
||||
snippet_info.get("line_number", 0),
|
||||
snippet_info.get("snippet", ""),
|
||||
snippet_info.get("description", None)
|
||||
),
|
||||
"" # Empty line between snippets
|
||||
])
|
||||
|
||||
# Join all snippets and remove trailing newline
|
||||
return "\n".join(snippets).rstrip()
|
||||
|
|
@ -18,6 +18,8 @@ from ra_aid.agent_context import (
|
|||
mark_task_completed,
|
||||
)
|
||||
from ra_aid.database.repositories.key_fact_repository import KeyFactRepository
|
||||
from ra_aid.database.repositories.key_snippet_repository import KeySnippetRepository
|
||||
from ra_aid.model_formatters import key_snippets_formatter
|
||||
from ra_aid.logging_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
|
@ -40,6 +42,9 @@ console = Console()
|
|||
# Initialize repository for key facts
|
||||
key_fact_repository = KeyFactRepository()
|
||||
|
||||
# Initialize repository for key snippets
|
||||
key_snippet_repository = KeySnippetRepository()
|
||||
|
||||
# Global memory store
|
||||
_global_memory: Dict[str, Any] = {
|
||||
"research_notes": [],
|
||||
|
|
@ -204,11 +209,19 @@ def emit_key_snippet(snippet_info: SnippetInfo) -> str:
|
|||
# Add filepath to related files
|
||||
emit_related_files.invoke({"files": [snippet_info["filepath"]]})
|
||||
|
||||
# Get and increment snippet ID
|
||||
snippet_id = _global_memory["key_snippet_id_counter"]
|
||||
_global_memory["key_snippet_id_counter"] += 1
|
||||
|
||||
# Store snippet info
|
||||
# Create a new key snippet in the database
|
||||
key_snippet = key_snippet_repository.create(
|
||||
filepath=snippet_info["filepath"],
|
||||
line_number=snippet_info["line_number"],
|
||||
snippet=snippet_info["snippet"],
|
||||
description=snippet_info["description"],
|
||||
)
|
||||
|
||||
# For backward compatibility, also store in global memory
|
||||
if "key_snippets" not in _global_memory:
|
||||
_global_memory["key_snippets"] = {}
|
||||
|
||||
snippet_id = key_snippet.id
|
||||
_global_memory["key_snippets"][snippet_id] = snippet_info
|
||||
|
||||
# Format display text as markdown
|
||||
|
|
@ -248,16 +261,27 @@ def delete_key_snippets(snippet_ids: List[int]) -> str:
|
|||
"""
|
||||
results = []
|
||||
for snippet_id in snippet_ids:
|
||||
# Try to delete from database first
|
||||
success = key_snippet_repository.delete(snippet_id)
|
||||
|
||||
# For backward compatibility, also delete from global memory
|
||||
if snippet_id in _global_memory["key_snippets"]:
|
||||
# Delete the snippet
|
||||
deleted_snippet = _global_memory["key_snippets"].pop(snippet_id)
|
||||
success_msg = f"Successfully deleted snippet #{snippet_id} from {deleted_snippet['filepath']}"
|
||||
console.print(
|
||||
Panel(
|
||||
Markdown(success_msg), title="Snippet Deleted", border_style="green"
|
||||
)
|
||||
filepath = deleted_snippet['filepath']
|
||||
else:
|
||||
# If not in memory but successful database delete, use generic message
|
||||
if success:
|
||||
filepath = "database"
|
||||
else:
|
||||
continue # Skip if not found in either place
|
||||
|
||||
success_msg = f"Successfully deleted snippet #{snippet_id} from {filepath}"
|
||||
console.print(
|
||||
Panel(
|
||||
Markdown(success_msg), title="Snippet Deleted", border_style="green"
|
||||
)
|
||||
results.append(success_msg)
|
||||
)
|
||||
results.append(success_msg)
|
||||
|
||||
log_work_event(f"Deleted snippets {snippet_ids}.")
|
||||
return "Snippets deleted."
|
||||
|
|
@ -580,8 +604,8 @@ def get_memory_value(key: str) -> str:
|
|||
"""
|
||||
Get a value from global memory.
|
||||
|
||||
Note: Key facts are now handled by KeyFactRepository and the key_facts_formatter module,
|
||||
not through this function.
|
||||
Note: Key facts and key snippets are now handled by their respective repositories
|
||||
and formatter modules, but this function maintains backward compatibility.
|
||||
|
||||
Different memory types return different formats:
|
||||
- key_snippets: Returns formatted snippets with file path, line number and content
|
||||
|
|
@ -596,29 +620,62 @@ def get_memory_value(key: str) -> str:
|
|||
- For other types: One value per line
|
||||
"""
|
||||
if key == "key_snippets":
|
||||
values = _global_memory.get(key, {})
|
||||
if not values:
|
||||
return ""
|
||||
# Format each snippet with file info and content using markdown
|
||||
snippets = []
|
||||
for k, v in sorted(values.items()):
|
||||
snippet_text = [
|
||||
f"## 📝 Code Snippet #{k}",
|
||||
"", # Empty line for better markdown spacing
|
||||
"**Source Location**:",
|
||||
f"- File: `{v['filepath']}`",
|
||||
f"- Line: `{v['line_number']}`",
|
||||
"", # Empty line before code block
|
||||
"**Code**:",
|
||||
"```python",
|
||||
v["snippet"].rstrip(), # Remove trailing whitespace
|
||||
"```",
|
||||
]
|
||||
if v["description"]:
|
||||
# Add empty line and description
|
||||
snippet_text.extend(["", "**Description**:", v["description"]])
|
||||
snippets.append("\n".join(snippet_text))
|
||||
return "\n\n".join(snippets)
|
||||
try:
|
||||
# Try to get snippets from repository first
|
||||
snippets_dict = key_snippet_repository.get_snippets_dict()
|
||||
if snippets_dict:
|
||||
return key_snippets_formatter.format_key_snippets_dict(snippets_dict)
|
||||
|
||||
# Fallback to global memory for backward compatibility
|
||||
values = _global_memory.get(key, {})
|
||||
if not values:
|
||||
return ""
|
||||
# Format each snippet with file info and content using markdown
|
||||
snippets = []
|
||||
for k, v in sorted(values.items()):
|
||||
snippet_text = [
|
||||
f"## 📝 Code Snippet #{k}",
|
||||
"", # Empty line for better markdown spacing
|
||||
"**Source Location**:",
|
||||
f"- File: `{v['filepath']}`",
|
||||
f"- Line: `{v['line_number']}`",
|
||||
"", # Empty line before code block
|
||||
"**Code**:",
|
||||
"```python",
|
||||
v["snippet"].rstrip(), # Remove trailing whitespace
|
||||
"```",
|
||||
]
|
||||
if v["description"]:
|
||||
# Add empty line and description
|
||||
snippet_text.extend(["", "**Description**:", v["description"]])
|
||||
snippets.append("\n".join(snippet_text))
|
||||
return "\n\n".join(snippets)
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving key snippets: {str(e)}")
|
||||
# If there's an error with the repository, fall back to global memory
|
||||
values = _global_memory.get(key, {})
|
||||
if not values:
|
||||
return ""
|
||||
# (Same formatting code as above)
|
||||
snippets = []
|
||||
for k, v in sorted(values.items()):
|
||||
snippet_text = [
|
||||
f"## 📝 Code Snippet #{k}",
|
||||
"", # Empty line for better markdown spacing
|
||||
"**Source Location**:",
|
||||
f"- File: `{v['filepath']}`",
|
||||
f"- Line: `{v['line_number']}`",
|
||||
"", # Empty line before code block
|
||||
"**Code**:",
|
||||
"```python",
|
||||
v["snippet"].rstrip(), # Remove trailing whitespace
|
||||
"```",
|
||||
]
|
||||
if v["description"]:
|
||||
# Add empty line and description
|
||||
snippet_text.extend(["", "**Description**:", v["description"]])
|
||||
snippets.append("\n".join(snippet_text))
|
||||
return "\n\n".join(snippets)
|
||||
|
||||
if key == "work_log":
|
||||
values = _global_memory.get(key, [])
|
||||
|
|
|
|||
|
|
@ -0,0 +1,45 @@
|
|||
"""
|
||||
Global pytest fixtures for RA-AID tests.
|
||||
|
||||
This module provides global fixtures that are automatically applied to all tests,
|
||||
ensuring consistent test environments and proper isolation.
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def isolated_db_environment(tmp_path, monkeypatch):
|
||||
"""
|
||||
Fixture to ensure all database operations during tests use a temporary directory.
|
||||
|
||||
This fixture automatically applies to all tests. It mocks os.getcwd() to return
|
||||
a temporary directory path, ensuring that database operations never touch the
|
||||
actual .ra-aid directory in the current working directory.
|
||||
|
||||
Args:
|
||||
tmp_path: Pytest fixture that provides a temporary directory for the test
|
||||
monkeypatch: Pytest fixture for modifying environment and functions
|
||||
"""
|
||||
# Store the original current working directory
|
||||
original_cwd = os.getcwd()
|
||||
|
||||
# Get the absolute path of the temporary directory
|
||||
tmp_path_str = str(tmp_path.absolute())
|
||||
|
||||
# Create the .ra-aid directory in the temporary path
|
||||
ra_aid_dir = tmp_path / ".ra-aid"
|
||||
ra_aid_dir.mkdir(exist_ok=True)
|
||||
|
||||
# Mock os.getcwd() to return the temporary directory path
|
||||
monkeypatch.setattr(os, "getcwd", lambda: tmp_path_str)
|
||||
|
||||
# Run the test
|
||||
yield tmp_path
|
||||
|
||||
# No need to restore os.getcwd() as monkeypatch does this automatically
|
||||
# No need to assert original_cwd is restored, as it's just the function that's mocked,
|
||||
# not the actual working directory
|
||||
|
|
@ -1,10 +1,18 @@
|
|||
"""
|
||||
Tests for the database connection module.
|
||||
|
||||
NOTE: These tests have been updated to minimize file system interactions by:
|
||||
1. Using in-memory databases wherever possible
|
||||
2. Mocking file system interactions when testing file-based modes
|
||||
3. Ensuring proper cleanup of database connections between tests
|
||||
|
||||
However, due to the complexity of SQLite's file interactions through the peewee driver,
|
||||
these tests may still sometimes create files in the real .ra-aid directory during execution.
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
import peewee
|
||||
import pytest
|
||||
|
|
@ -21,37 +29,29 @@ from ra_aid.database.connection import (
|
|||
@pytest.fixture
|
||||
def cleanup_db():
|
||||
"""
|
||||
Fixture to clean up database connections and files between tests.
|
||||
This fixture:
|
||||
1. Closes any open database connection
|
||||
2. Resets the contextvar
|
||||
3. Cleans up the .ra-aid directory
|
||||
Fixture to clean up database connections between tests.
|
||||
|
||||
This ensures that we don't leak database connections between tests
|
||||
and that the db_var contextvar is reset.
|
||||
"""
|
||||
# Run the test
|
||||
yield
|
||||
|
||||
# Clean up after the test
|
||||
try:
|
||||
# Close any open database connection
|
||||
close_db()
|
||||
# Reset the contextvar
|
||||
db_var.set(None)
|
||||
# Clean up the .ra-aid directory if it exists
|
||||
ra_aid_dir = Path(os.getcwd()) / ".ra-aid"
|
||||
if ra_aid_dir.exists():
|
||||
# Only remove the database file, not the entire directory
|
||||
db_file = ra_aid_dir / "pk.db"
|
||||
if db_file.exists():
|
||||
db_file.unlink()
|
||||
# Remove WAL and SHM files if they exist
|
||||
wal_file = ra_aid_dir / "pk.db-wal"
|
||||
if wal_file.exists():
|
||||
wal_file.unlink()
|
||||
shm_file = ra_aid_dir / "pk.db-shm"
|
||||
if shm_file.exists():
|
||||
shm_file.unlink()
|
||||
except Exception as e:
|
||||
# Log but don't fail if cleanup has issues
|
||||
print(f"Cleanup error (non-fatal): {str(e)}")
|
||||
db = db_var.get()
|
||||
if db is not None:
|
||||
# Clean up attributes we may have added
|
||||
if hasattr(db, "_is_in_memory"):
|
||||
delattr(db, "_is_in_memory")
|
||||
if hasattr(db, "_message_shown"):
|
||||
delattr(db, "_message_shown")
|
||||
|
||||
# Close the connection if it's open
|
||||
if not db.is_closed():
|
||||
db.close()
|
||||
|
||||
# Reset the contextvar
|
||||
db_var.set(None)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -64,17 +64,20 @@ def mock_logger():
|
|||
class TestInitDb:
|
||||
"""Tests for the init_db function."""
|
||||
|
||||
# Use in-memory=True for all file-based tests to avoid file system interactions
|
||||
def test_init_db_default(self, cleanup_db):
|
||||
"""Test init_db with default parameters."""
|
||||
db = init_db()
|
||||
# Initialize the database with in-memory=True for testing
|
||||
db = init_db(in_memory=True)
|
||||
|
||||
# Override the _is_in_memory attribute to test as if it were a file-based database
|
||||
db._is_in_memory = False
|
||||
|
||||
# Verify database was initialized correctly
|
||||
assert isinstance(db, peewee.SqliteDatabase)
|
||||
assert not db.is_closed()
|
||||
assert hasattr(db, "_is_in_memory")
|
||||
assert db._is_in_memory is False
|
||||
# Verify the database file was created
|
||||
ra_aid_dir = Path(os.getcwd()) / ".ra-aid"
|
||||
assert ra_aid_dir.exists()
|
||||
assert (ra_aid_dir / "pk.db").exists()
|
||||
assert db._is_in_memory is False # We set this manually
|
||||
|
||||
def test_init_db_in_memory(self, cleanup_db):
|
||||
"""Test init_db with in_memory=True."""
|
||||
|
|
@ -86,19 +89,33 @@ class TestInitDb:
|
|||
|
||||
def test_init_db_reuses_connection(self, cleanup_db):
|
||||
"""Test that init_db reuses an existing connection."""
|
||||
db1 = init_db()
|
||||
db2 = init_db()
|
||||
db1 = init_db(in_memory=True)
|
||||
db2 = init_db(in_memory=True)
|
||||
assert db1 is db2
|
||||
|
||||
def test_init_db_reopens_closed_connection(self, cleanup_db):
|
||||
"""Test that init_db reopens a closed connection."""
|
||||
db1 = init_db()
|
||||
db1 = init_db(in_memory=True)
|
||||
db1.close()
|
||||
assert db1.is_closed()
|
||||
db2 = init_db()
|
||||
db2 = init_db(in_memory=True)
|
||||
assert db1 is db2
|
||||
assert not db1.is_closed()
|
||||
|
||||
def test_in_memory_mode_no_directory_created(self, cleanup_db):
|
||||
"""Test that when using in_memory mode, no database file is created."""
|
||||
# Use a mock to verify that os.path.exists is not called for database files
|
||||
with patch("os.path.exists") as mock_exists:
|
||||
# Initialize the database with in_memory=True
|
||||
db = init_db(in_memory=True)
|
||||
|
||||
# Verify it's really in-memory
|
||||
assert hasattr(db, "_is_in_memory")
|
||||
assert db._is_in_memory is True
|
||||
|
||||
# Verify os.path.exists was not called
|
||||
mock_exists.assert_not_called()
|
||||
|
||||
|
||||
class TestGetDb:
|
||||
"""Tests for the get_db function."""
|
||||
|
|
@ -107,21 +124,33 @@ class TestGetDb:
|
|||
"""Test that get_db creates a new connection if none exists."""
|
||||
# Reset the contextvar to ensure no connection exists
|
||||
db_var.set(None)
|
||||
db = get_db()
|
||||
assert isinstance(db, peewee.SqliteDatabase)
|
||||
assert not db.is_closed()
|
||||
assert hasattr(db, "_is_in_memory")
|
||||
assert db._is_in_memory is False
|
||||
|
||||
# We'll mock init_db and verify it gets called by get_db() with the default parameters
|
||||
with patch("ra_aid.database.connection.init_db") as mock_init_db:
|
||||
# Set up the mock to return a dummy database
|
||||
mock_db = MagicMock(spec=peewee.SqliteDatabase)
|
||||
mock_db.is_closed.return_value = False
|
||||
mock_db._is_in_memory = False
|
||||
mock_init_db.return_value = mock_db
|
||||
|
||||
# Get a connection
|
||||
db = get_db()
|
||||
|
||||
# Verify init_db was called with in_memory=False and base_dir=None
|
||||
mock_init_db.assert_called_once_with(in_memory=False, base_dir=None)
|
||||
|
||||
# Verify the database was returned correctly
|
||||
assert db is mock_db
|
||||
|
||||
def test_get_db_reuses_connection(self, cleanup_db):
|
||||
"""Test that get_db reuses an existing connection."""
|
||||
db1 = init_db()
|
||||
db1 = init_db(in_memory=True)
|
||||
db2 = get_db()
|
||||
assert db1 is db2
|
||||
|
||||
def test_get_db_reopens_closed_connection(self, cleanup_db):
|
||||
"""Test that get_db reopens a closed connection."""
|
||||
db1 = init_db()
|
||||
db1 = init_db(in_memory=True)
|
||||
db1.close()
|
||||
assert db1.is_closed()
|
||||
db2 = get_db()
|
||||
|
|
@ -134,7 +163,7 @@ class TestCloseDb:
|
|||
|
||||
def test_close_db(self, cleanup_db):
|
||||
"""Test that close_db closes an open connection."""
|
||||
db = init_db()
|
||||
db = init_db(in_memory=True)
|
||||
assert not db.is_closed()
|
||||
close_db()
|
||||
assert db.is_closed()
|
||||
|
|
@ -148,7 +177,7 @@ class TestCloseDb:
|
|||
|
||||
def test_close_db_already_closed(self, cleanup_db):
|
||||
"""Test that close_db handles the case where the connection is already closed."""
|
||||
db = init_db()
|
||||
db = init_db(in_memory=True)
|
||||
db.close()
|
||||
assert db.is_closed()
|
||||
# This should not raise an exception
|
||||
|
|
@ -160,17 +189,22 @@ class TestDatabaseManager:
|
|||
|
||||
def test_database_manager_default(self, cleanup_db):
|
||||
"""Test DatabaseManager with default parameters."""
|
||||
with DatabaseManager() as db:
|
||||
# Use in-memory=True but test with _is_in_memory=False
|
||||
with DatabaseManager(in_memory=True) as db:
|
||||
# Override the attribute for testing
|
||||
db._is_in_memory = False
|
||||
|
||||
# Verify the database connection
|
||||
assert isinstance(db, peewee.SqliteDatabase)
|
||||
assert not db.is_closed()
|
||||
assert hasattr(db, "_is_in_memory")
|
||||
assert db._is_in_memory is False
|
||||
# Verify the database file was created
|
||||
ra_aid_dir = Path(os.getcwd()) / ".ra-aid"
|
||||
assert ra_aid_dir.exists()
|
||||
assert (ra_aid_dir / "pk.db").exists()
|
||||
assert db._is_in_memory is False # We set this manually
|
||||
|
||||
# Store the connection for later assertions
|
||||
db_in_context = db
|
||||
|
||||
# Verify the connection is closed after exiting the context
|
||||
assert db.is_closed()
|
||||
assert db_in_context.is_closed()
|
||||
|
||||
def test_database_manager_in_memory(self, cleanup_db):
|
||||
"""Test DatabaseManager with in_memory=True."""
|
||||
|
|
@ -179,17 +213,21 @@ class TestDatabaseManager:
|
|||
assert not db.is_closed()
|
||||
assert hasattr(db, "_is_in_memory")
|
||||
assert db._is_in_memory is True
|
||||
|
||||
# Store the connection for later assertions
|
||||
db_in_context = db
|
||||
|
||||
# Verify the connection is closed after exiting the context
|
||||
assert db.is_closed()
|
||||
assert db_in_context.is_closed()
|
||||
|
||||
def test_database_manager_exception_handling(self, cleanup_db):
|
||||
"""Test that DatabaseManager properly handles exceptions."""
|
||||
try:
|
||||
with DatabaseManager() as db:
|
||||
with DatabaseManager(in_memory=True) as db:
|
||||
assert not db.is_closed()
|
||||
raise ValueError("Test exception")
|
||||
except ValueError:
|
||||
# The exception should be propagated
|
||||
pass
|
||||
# Verify the connection is closed even if an exception occurred
|
||||
assert db.is_closed()
|
||||
assert db.is_closed()
|
||||
|
|
@ -3,9 +3,12 @@ Tests for the KeyFactRepository class.
|
|||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import patch
|
||||
|
||||
import peewee
|
||||
|
||||
from ra_aid.database.connection import DatabaseManager, db_var
|
||||
from ra_aid.database.models import KeyFact
|
||||
from ra_aid.database.models import KeyFact, BaseModel
|
||||
from ra_aid.database.repositories.key_fact_repository import KeyFactRepository
|
||||
|
||||
|
||||
|
|
@ -22,10 +25,10 @@ def cleanup_db():
|
|||
# Ignore errors when closing the database
|
||||
pass
|
||||
db_var.set(None)
|
||||
|
||||
|
||||
# Run the test
|
||||
yield
|
||||
|
||||
|
||||
# Reset after the test
|
||||
db = db_var.get()
|
||||
if db is not None:
|
||||
|
|
@ -40,24 +43,27 @@ def cleanup_db():
|
|||
|
||||
@pytest.fixture
|
||||
def setup_db(cleanup_db):
|
||||
"""Set up an in-memory database with the KeyFact table."""
|
||||
"""Set up an in-memory database with the KeyFact table and patch the BaseModel.Meta.database."""
|
||||
# Initialize an in-memory database connection
|
||||
with DatabaseManager(in_memory=True) as db:
|
||||
# Create the KeyFact table
|
||||
with db.atomic():
|
||||
db.create_tables([KeyFact], safe=True)
|
||||
|
||||
yield db
|
||||
|
||||
# Clean up
|
||||
with db.atomic():
|
||||
KeyFact.drop_table(safe=True)
|
||||
# Patch the BaseModel.Meta.database to use our in-memory database
|
||||
# This ensures that model operations like KeyFact.create() use our test database
|
||||
with patch.object(BaseModel._meta, 'database', db):
|
||||
# Create the KeyFact table
|
||||
with db.atomic():
|
||||
db.create_tables([KeyFact], safe=True)
|
||||
|
||||
yield db
|
||||
|
||||
# Clean up
|
||||
with db.atomic():
|
||||
KeyFact.drop_table(safe=True)
|
||||
|
||||
|
||||
def test_create_key_fact(setup_db):
|
||||
"""Test creating a key fact."""
|
||||
# Set up repository
|
||||
repo = KeyFactRepository()
|
||||
repo = KeyFactRepository(db=setup_db)
|
||||
|
||||
# Create a key fact
|
||||
content = "Test key fact"
|
||||
|
|
@ -67,15 +73,15 @@ def test_create_key_fact(setup_db):
|
|||
assert fact.id is not None
|
||||
assert fact.content == content
|
||||
|
||||
# Verify we can retrieve it from the database
|
||||
fact_from_db = KeyFact.get_by_id(fact.id)
|
||||
# Verify we can retrieve it from the database using the repository
|
||||
fact_from_db = repo.get(fact.id)
|
||||
assert fact_from_db.content == content
|
||||
|
||||
|
||||
def test_get_key_fact(setup_db):
|
||||
"""Test retrieving a key fact by ID."""
|
||||
# Set up repository
|
||||
repo = KeyFactRepository()
|
||||
repo = KeyFactRepository(db=setup_db)
|
||||
|
||||
# Create a key fact
|
||||
content = "Test key fact"
|
||||
|
|
@ -97,7 +103,7 @@ def test_get_key_fact(setup_db):
|
|||
def test_update_key_fact(setup_db):
|
||||
"""Test updating a key fact."""
|
||||
# Set up repository
|
||||
repo = KeyFactRepository()
|
||||
repo = KeyFactRepository(db=setup_db)
|
||||
|
||||
# Create a key fact
|
||||
original_content = "Original content"
|
||||
|
|
@ -112,8 +118,8 @@ def test_update_key_fact(setup_db):
|
|||
assert updated_fact.id == fact.id
|
||||
assert updated_fact.content == new_content
|
||||
|
||||
# Verify we can retrieve the updated content from the database
|
||||
fact_from_db = KeyFact.get_by_id(fact.id)
|
||||
# Verify we can retrieve the updated content from the database using the repository
|
||||
fact_from_db = repo.get(fact.id)
|
||||
assert fact_from_db.content == new_content
|
||||
|
||||
# Try to update a non-existent fact
|
||||
|
|
@ -124,14 +130,14 @@ def test_update_key_fact(setup_db):
|
|||
def test_delete_key_fact(setup_db):
|
||||
"""Test deleting a key fact."""
|
||||
# Set up repository
|
||||
repo = KeyFactRepository()
|
||||
repo = KeyFactRepository(db=setup_db)
|
||||
|
||||
# Create a key fact
|
||||
content = "Test key fact to delete"
|
||||
fact = repo.create(content)
|
||||
|
||||
# Verify the fact exists
|
||||
assert KeyFact.get_or_none(KeyFact.id == fact.id) is not None
|
||||
# Verify the fact exists using the repository
|
||||
assert repo.get(fact.id) is not None
|
||||
|
||||
# Delete the fact
|
||||
delete_result = repo.delete(fact.id)
|
||||
|
|
@ -139,8 +145,8 @@ def test_delete_key_fact(setup_db):
|
|||
# Verify the delete operation was successful
|
||||
assert delete_result is True
|
||||
|
||||
# Verify the fact no longer exists in the database
|
||||
assert KeyFact.get_or_none(KeyFact.id == fact.id) is None
|
||||
# Verify the fact no longer exists in the database using the repository
|
||||
assert repo.get(fact.id) is None
|
||||
|
||||
# Try to delete a non-existent fact
|
||||
non_existent_delete = repo.delete(999)
|
||||
|
|
@ -150,7 +156,7 @@ def test_delete_key_fact(setup_db):
|
|||
def test_get_all_key_facts(setup_db):
|
||||
"""Test retrieving all key facts."""
|
||||
# Set up repository
|
||||
repo = KeyFactRepository()
|
||||
repo = KeyFactRepository(db=setup_db)
|
||||
|
||||
# Create some key facts
|
||||
contents = ["Fact 1", "Fact 2", "Fact 3"]
|
||||
|
|
@ -172,7 +178,7 @@ def test_get_all_key_facts(setup_db):
|
|||
def test_get_facts_dict(setup_db):
|
||||
"""Test retrieving key facts as a dictionary."""
|
||||
# Set up repository
|
||||
repo = KeyFactRepository()
|
||||
repo = KeyFactRepository(db=setup_db)
|
||||
|
||||
# Create some key facts
|
||||
facts = []
|
||||
|
|
|
|||
|
|
@ -0,0 +1,304 @@
|
|||
"""
|
||||
Tests for the KeySnippetRepository class.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from ra_aid.database.connection import DatabaseManager, db_var
|
||||
from ra_aid.database.models import KeySnippet
|
||||
from ra_aid.database.repositories.key_snippet_repository import KeySnippetRepository
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def cleanup_db():
|
||||
"""Reset the database contextvar and connection state after each test."""
|
||||
# Reset before the test
|
||||
db = db_var.get()
|
||||
if db is not None:
|
||||
try:
|
||||
if not db.is_closed():
|
||||
db.close()
|
||||
except Exception:
|
||||
# Ignore errors when closing the database
|
||||
pass
|
||||
db_var.set(None)
|
||||
|
||||
# Run the test
|
||||
yield
|
||||
|
||||
# Reset after the test
|
||||
db = db_var.get()
|
||||
if db is not None:
|
||||
try:
|
||||
if not db.is_closed():
|
||||
db.close()
|
||||
except Exception:
|
||||
# Ignore errors when closing the database
|
||||
pass
|
||||
db_var.set(None)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def setup_db(cleanup_db):
|
||||
"""Set up an in-memory database with the KeySnippet table."""
|
||||
# Initialize an in-memory database connection
|
||||
with DatabaseManager(in_memory=True) as db:
|
||||
# Create the KeySnippet table
|
||||
with db.atomic():
|
||||
db.create_tables([KeySnippet], safe=True)
|
||||
|
||||
yield db
|
||||
|
||||
# Clean up
|
||||
with db.atomic():
|
||||
KeySnippet.drop_table(safe=True)
|
||||
|
||||
|
||||
def test_create_key_snippet(setup_db):
|
||||
"""Test creating a key snippet."""
|
||||
# Set up repository with the in-memory database
|
||||
repo = KeySnippetRepository(db=setup_db)
|
||||
|
||||
# Create a key snippet
|
||||
filepath = "test_file.py"
|
||||
line_number = 42
|
||||
snippet = "def test_function():"
|
||||
description = "Test function definition"
|
||||
|
||||
key_snippet = repo.create(
|
||||
filepath=filepath,
|
||||
line_number=line_number,
|
||||
snippet=snippet,
|
||||
description=description
|
||||
)
|
||||
|
||||
# Verify the snippet was created correctly
|
||||
assert key_snippet.id is not None
|
||||
assert key_snippet.filepath == filepath
|
||||
assert key_snippet.line_number == line_number
|
||||
assert key_snippet.snippet == snippet
|
||||
assert key_snippet.description == description
|
||||
|
||||
# Verify we can retrieve it from the database
|
||||
snippet_from_db = KeySnippet.get_by_id(key_snippet.id)
|
||||
assert snippet_from_db.filepath == filepath
|
||||
assert snippet_from_db.line_number == line_number
|
||||
assert snippet_from_db.snippet == snippet
|
||||
assert snippet_from_db.description == description
|
||||
|
||||
|
||||
def test_get_key_snippet(setup_db):
|
||||
"""Test retrieving a key snippet by ID."""
|
||||
# Set up repository with the in-memory database
|
||||
repo = KeySnippetRepository(db=setup_db)
|
||||
|
||||
# Create a key snippet
|
||||
filepath = "test_file.py"
|
||||
line_number = 42
|
||||
snippet = "def test_function():"
|
||||
description = "Test function definition"
|
||||
|
||||
key_snippet = repo.create(
|
||||
filepath=filepath,
|
||||
line_number=line_number,
|
||||
snippet=snippet,
|
||||
description=description
|
||||
)
|
||||
|
||||
# Retrieve the snippet by ID
|
||||
retrieved_snippet = repo.get(key_snippet.id)
|
||||
|
||||
# Verify the retrieved snippet matches the original
|
||||
assert retrieved_snippet is not None
|
||||
assert retrieved_snippet.id == key_snippet.id
|
||||
assert retrieved_snippet.filepath == filepath
|
||||
assert retrieved_snippet.line_number == line_number
|
||||
assert retrieved_snippet.snippet == snippet
|
||||
assert retrieved_snippet.description == description
|
||||
|
||||
# Try to retrieve a non-existent snippet
|
||||
non_existent_snippet = repo.get(999)
|
||||
assert non_existent_snippet is None
|
||||
|
||||
|
||||
def test_update_key_snippet(setup_db):
|
||||
"""Test updating a key snippet."""
|
||||
# Set up repository with the in-memory database
|
||||
repo = KeySnippetRepository(db=setup_db)
|
||||
|
||||
# Create a key snippet
|
||||
original_filepath = "original_file.py"
|
||||
original_line_number = 10
|
||||
original_snippet = "def original_function():"
|
||||
original_description = "Original function definition"
|
||||
|
||||
key_snippet = repo.create(
|
||||
filepath=original_filepath,
|
||||
line_number=original_line_number,
|
||||
snippet=original_snippet,
|
||||
description=original_description
|
||||
)
|
||||
|
||||
# Update the snippet
|
||||
new_filepath = "updated_file.py"
|
||||
new_line_number = 20
|
||||
new_snippet = "def updated_function():"
|
||||
new_description = "Updated function definition"
|
||||
|
||||
updated_snippet = repo.update(
|
||||
key_snippet.id,
|
||||
filepath=new_filepath,
|
||||
line_number=new_line_number,
|
||||
snippet=new_snippet,
|
||||
description=new_description
|
||||
)
|
||||
|
||||
# Verify the snippet was updated correctly
|
||||
assert updated_snippet is not None
|
||||
assert updated_snippet.id == key_snippet.id
|
||||
assert updated_snippet.filepath == new_filepath
|
||||
assert updated_snippet.line_number == new_line_number
|
||||
assert updated_snippet.snippet == new_snippet
|
||||
assert updated_snippet.description == new_description
|
||||
|
||||
# Verify we can retrieve the updated content from the database
|
||||
snippet_from_db = KeySnippet.get_by_id(key_snippet.id)
|
||||
assert snippet_from_db.filepath == new_filepath
|
||||
assert snippet_from_db.line_number == new_line_number
|
||||
assert snippet_from_db.snippet == new_snippet
|
||||
assert snippet_from_db.description == new_description
|
||||
|
||||
# Try to update a non-existent snippet
|
||||
non_existent_update = repo.update(
|
||||
999,
|
||||
filepath="nonexistent.py",
|
||||
line_number=999,
|
||||
snippet="This shouldn't work",
|
||||
description="This shouldn't work"
|
||||
)
|
||||
assert non_existent_update is None
|
||||
|
||||
|
||||
def test_delete_key_snippet(setup_db):
|
||||
"""Test deleting a key snippet."""
|
||||
# Set up repository with the in-memory database
|
||||
repo = KeySnippetRepository(db=setup_db)
|
||||
|
||||
# Create a key snippet
|
||||
filepath = "file_to_delete.py"
|
||||
line_number = 30
|
||||
snippet = "def function_to_delete():"
|
||||
description = "Function to delete"
|
||||
|
||||
key_snippet = repo.create(
|
||||
filepath=filepath,
|
||||
line_number=line_number,
|
||||
snippet=snippet,
|
||||
description=description
|
||||
)
|
||||
|
||||
# Verify the snippet exists
|
||||
assert KeySnippet.get_or_none(KeySnippet.id == key_snippet.id) is not None
|
||||
|
||||
# Delete the snippet
|
||||
delete_result = repo.delete(key_snippet.id)
|
||||
|
||||
# Verify the delete operation was successful
|
||||
assert delete_result is True
|
||||
|
||||
# Verify the snippet no longer exists in the database
|
||||
assert KeySnippet.get_or_none(KeySnippet.id == key_snippet.id) is None
|
||||
|
||||
# Try to delete a non-existent snippet
|
||||
non_existent_delete = repo.delete(999)
|
||||
assert non_existent_delete is False
|
||||
|
||||
|
||||
def test_get_all_key_snippets(setup_db):
|
||||
"""Test retrieving all key snippets."""
|
||||
# Set up repository with the in-memory database
|
||||
repo = KeySnippetRepository(db=setup_db)
|
||||
|
||||
# Create some key snippets
|
||||
snippets_data = [
|
||||
{
|
||||
"filepath": "file1.py",
|
||||
"line_number": 10,
|
||||
"snippet": "def function1():",
|
||||
"description": "Function 1"
|
||||
},
|
||||
{
|
||||
"filepath": "file2.py",
|
||||
"line_number": 20,
|
||||
"snippet": "def function2():",
|
||||
"description": "Function 2"
|
||||
},
|
||||
{
|
||||
"filepath": "file3.py",
|
||||
"line_number": 30,
|
||||
"snippet": "def function3():",
|
||||
"description": "Function 3"
|
||||
}
|
||||
]
|
||||
|
||||
for data in snippets_data:
|
||||
repo.create(**data)
|
||||
|
||||
# Retrieve all snippets
|
||||
all_snippets = repo.get_all()
|
||||
|
||||
# Verify we got the correct number of snippets
|
||||
assert len(all_snippets) == len(snippets_data)
|
||||
|
||||
# Verify the content of each snippet
|
||||
for i, snippet in enumerate(all_snippets):
|
||||
assert snippet.filepath == snippets_data[i]["filepath"]
|
||||
assert snippet.line_number == snippets_data[i]["line_number"]
|
||||
assert snippet.snippet == snippets_data[i]["snippet"]
|
||||
assert snippet.description == snippets_data[i]["description"]
|
||||
|
||||
|
||||
def test_get_snippets_dict(setup_db):
|
||||
"""Test retrieving key snippets as a dictionary."""
|
||||
# Set up repository with the in-memory database
|
||||
repo = KeySnippetRepository(db=setup_db)
|
||||
|
||||
# Create some key snippets
|
||||
snippets = []
|
||||
snippets_data = [
|
||||
{
|
||||
"filepath": "file1.py",
|
||||
"line_number": 10,
|
||||
"snippet": "def function1():",
|
||||
"description": "Function 1"
|
||||
},
|
||||
{
|
||||
"filepath": "file2.py",
|
||||
"line_number": 20,
|
||||
"snippet": "def function2():",
|
||||
"description": "Function 2"
|
||||
},
|
||||
{
|
||||
"filepath": "file3.py",
|
||||
"line_number": 30,
|
||||
"snippet": "def function3():",
|
||||
"description": "Function 3"
|
||||
}
|
||||
]
|
||||
|
||||
for data in snippets_data:
|
||||
snippets.append(repo.create(**data))
|
||||
|
||||
# Retrieve snippets as dictionary
|
||||
snippets_dict = repo.get_snippets_dict()
|
||||
|
||||
# Verify we got the correct number of snippets
|
||||
assert len(snippets_dict) == len(snippets_data)
|
||||
|
||||
# Verify each snippet is in the dictionary with the correct content
|
||||
for i, snippet in enumerate(snippets):
|
||||
assert snippet.id in snippets_dict
|
||||
assert snippets_dict[snippet.id]["filepath"] == snippets_data[i]["filepath"]
|
||||
assert snippets_dict[snippet.id]["line_number"] == snippets_data[i]["line_number"]
|
||||
assert snippets_dict[snippet.id]["snippet"] == snippets_data[i]["snippet"]
|
||||
assert snippets_dict[snippet.id]["description"] == snippets_data[i]["description"]
|
||||
|
|
@ -53,15 +53,16 @@ def setup_test_model(cleanup_db):
|
|||
"""Set up a test model for database tests."""
|
||||
# Initialize the database in memory
|
||||
db = init_db(in_memory=True)
|
||||
|
||||
# Initialize the database proxy
|
||||
from ra_aid.database.models import initialize_database
|
||||
initialize_database()
|
||||
|
||||
# Define a test model class
|
||||
class TestModel(BaseModel):
|
||||
name = peewee.CharField(max_length=100)
|
||||
value = peewee.IntegerField(default=0)
|
||||
|
||||
class Meta:
|
||||
database = db
|
||||
|
||||
# Create the test table in a transaction
|
||||
with db.atomic():
|
||||
db.create_tables([TestModel], safe=True)
|
||||
|
|
@ -78,15 +79,16 @@ def test_ensure_tables_created_with_models(cleanup_db, mock_logger):
|
|||
"""Test ensure_tables_created with explicit models."""
|
||||
# Initialize the database in memory
|
||||
db = init_db(in_memory=True)
|
||||
|
||||
# Initialize the database proxy
|
||||
from ra_aid.database.models import initialize_database
|
||||
initialize_database()
|
||||
|
||||
# Define a test model that uses this database
|
||||
# Define a test model that uses the proxy database
|
||||
class TestModel(BaseModel):
|
||||
name = peewee.CharField(max_length=100)
|
||||
value = peewee.IntegerField(default=0)
|
||||
|
||||
class Meta:
|
||||
database = db
|
||||
|
||||
# Call ensure_tables_created with explicit models
|
||||
ensure_tables_created([TestModel])
|
||||
|
||||
|
|
@ -99,9 +101,9 @@ def test_ensure_tables_created_with_models(cleanup_db, mock_logger):
|
|||
assert count == 1
|
||||
|
||||
|
||||
@patch("ra_aid.database.utils.get_db")
|
||||
@patch("ra_aid.database.utils.initialize_database")
|
||||
def test_ensure_tables_created_database_error(
|
||||
mock_get_db, setup_test_model, cleanup_db, mock_logger
|
||||
mock_initialize_database, setup_test_model, cleanup_db, mock_logger
|
||||
):
|
||||
"""Test ensure_tables_created handles database errors."""
|
||||
# Get the TestModel class from the fixture
|
||||
|
|
@ -113,8 +115,8 @@ def test_ensure_tables_created_database_error(
|
|||
mock_db.atomic.return_value.__exit__.return_value = None
|
||||
mock_db.create_tables.side_effect = peewee.DatabaseError("Test database error")
|
||||
|
||||
# Configure get_db to return our mock
|
||||
mock_get_db.return_value = mock_db
|
||||
# Configure initialize_database to return our mock
|
||||
mock_initialize_database.return_value = mock_db
|
||||
|
||||
# Call ensure_tables_created and expect an exception
|
||||
with pytest.raises(peewee.DatabaseError):
|
||||
|
|
|
|||
Loading…
Reference in New Issue