key snippets db
This commit is contained in:
parent
be2eb298a5
commit
038e7b886c
|
|
@ -611,6 +611,9 @@ def main():
|
||||||
config=config,
|
config=config,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Run cleanup tasks before exiting database context
|
||||||
|
run_cleanup()
|
||||||
|
|
||||||
except (KeyboardInterrupt, AgentInterrupt):
|
except (KeyboardInterrupt, AgentInterrupt):
|
||||||
print()
|
print()
|
||||||
print(" 👋 Bye!")
|
print(" 👋 Bye!")
|
||||||
|
|
@ -618,5 +621,17 @@ def main():
|
||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
|
|
||||||
|
|
||||||
|
def run_cleanup():
|
||||||
|
"""Run cleanup tasks after main execution."""
|
||||||
|
try:
|
||||||
|
# Import the key facts cleaner agent
|
||||||
|
from ra_aid.agents.key_facts_gc_agent import run_key_facts_gc_agent
|
||||||
|
|
||||||
|
# Run the key facts garbage collection agent regardless of the number of facts
|
||||||
|
run_key_facts_gc_agent()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to run cleanup tasks: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|
@ -13,7 +13,7 @@ from ra_aid.database.migrations import (
|
||||||
get_migration_status,
|
get_migration_status,
|
||||||
init_migrations,
|
init_migrations,
|
||||||
)
|
)
|
||||||
from ra_aid.database.models import BaseModel
|
from ra_aid.database.models import BaseModel, initialize_database
|
||||||
from ra_aid.database.utils import ensure_tables_created, get_model_count, truncate_table
|
from ra_aid.database.utils import ensure_tables_created, get_model_count, truncate_table
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
|
@ -22,6 +22,7 @@ __all__ = [
|
||||||
"close_db",
|
"close_db",
|
||||||
"DatabaseManager",
|
"DatabaseManager",
|
||||||
"BaseModel",
|
"BaseModel",
|
||||||
|
"initialize_database",
|
||||||
"get_model_count",
|
"get_model_count",
|
||||||
"truncate_table",
|
"truncate_table",
|
||||||
"ensure_tables_created",
|
"ensure_tables_created",
|
||||||
|
|
|
||||||
|
|
@ -14,6 +14,9 @@ import peewee
|
||||||
|
|
||||||
from ra_aid.logging_config import get_logger
|
from ra_aid.logging_config import get_logger
|
||||||
|
|
||||||
|
# Import initialize_database after it's defined in models.py
|
||||||
|
# We need to do the import inside functions to avoid circular imports
|
||||||
|
|
||||||
# Create contextvar to hold the database connection
|
# Create contextvar to hold the database connection
|
||||||
db_var = contextvars.ContextVar("db", default=None)
|
db_var = contextvars.ContextVar("db", default=None)
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
@ -34,16 +37,23 @@ class DatabaseManager:
|
||||||
# Or with in-memory database:
|
# Or with in-memory database:
|
||||||
with DatabaseManager(in_memory=True) as db:
|
with DatabaseManager(in_memory=True) as db:
|
||||||
# Use in-memory database
|
# Use in-memory database
|
||||||
|
|
||||||
|
# Or with custom base directory:
|
||||||
|
with DatabaseManager(base_dir="/custom/path") as db:
|
||||||
|
# Use database in custom directory
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, in_memory: bool = False):
|
def __init__(self, in_memory: bool = False, base_dir: Optional[str] = None):
|
||||||
"""
|
"""
|
||||||
Initialize the DatabaseManager.
|
Initialize the DatabaseManager.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
in_memory: Whether to use an in-memory database (default: False)
|
in_memory: Whether to use an in-memory database (default: False)
|
||||||
|
base_dir: Optional base directory to use instead of current working directory.
|
||||||
|
If None, uses os.getcwd() (default: None)
|
||||||
"""
|
"""
|
||||||
self.in_memory = in_memory
|
self.in_memory = in_memory
|
||||||
|
self.base_dir = base_dir
|
||||||
|
|
||||||
def __enter__(self) -> peewee.SqliteDatabase:
|
def __enter__(self) -> peewee.SqliteDatabase:
|
||||||
"""
|
"""
|
||||||
|
|
@ -52,7 +62,19 @@ class DatabaseManager:
|
||||||
Returns:
|
Returns:
|
||||||
peewee.SqliteDatabase: The initialized database connection
|
peewee.SqliteDatabase: The initialized database connection
|
||||||
"""
|
"""
|
||||||
return init_db(in_memory=self.in_memory)
|
db = init_db(in_memory=self.in_memory, base_dir=self.base_dir)
|
||||||
|
|
||||||
|
# Initialize the database proxy in models.py
|
||||||
|
try:
|
||||||
|
# Import here to avoid circular imports
|
||||||
|
from ra_aid.database.models import initialize_database
|
||||||
|
initialize_database()
|
||||||
|
except ImportError as e:
|
||||||
|
logger.error(f"Failed to import initialize_database: {str(e)}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error initializing database proxy: {str(e)}")
|
||||||
|
|
||||||
|
return db
|
||||||
|
|
||||||
def __exit__(
|
def __exit__(
|
||||||
self,
|
self,
|
||||||
|
|
@ -74,7 +96,7 @@ class DatabaseManager:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def init_db(in_memory: bool = False) -> peewee.SqliteDatabase:
|
def init_db(in_memory: bool = False, base_dir: Optional[str] = None) -> peewee.SqliteDatabase:
|
||||||
"""
|
"""
|
||||||
Initialize the database connection.
|
Initialize the database connection.
|
||||||
|
|
||||||
|
|
@ -84,6 +106,8 @@ def init_db(in_memory: bool = False) -> peewee.SqliteDatabase:
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
in_memory: Whether to use an in-memory database (default: False)
|
in_memory: Whether to use an in-memory database (default: False)
|
||||||
|
base_dir: Optional base directory to use instead of current working directory.
|
||||||
|
If None, uses os.getcwd() (default: None)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
peewee.SqliteDatabase: The initialized database connection
|
peewee.SqliteDatabase: The initialized database connection
|
||||||
|
|
@ -110,9 +134,9 @@ def init_db(in_memory: bool = False) -> peewee.SqliteDatabase:
|
||||||
db_path = ":memory:"
|
db_path = ":memory:"
|
||||||
logger.debug("Using in-memory SQLite database")
|
logger.debug("Using in-memory SQLite database")
|
||||||
else:
|
else:
|
||||||
# Get current working directory and create .ra-aid directory if it doesn't exist
|
# Get base directory (use current working directory if not provided)
|
||||||
cwd = os.getcwd()
|
cwd = base_dir if base_dir is not None else os.getcwd()
|
||||||
logger.debug(f"Current working directory: {cwd}")
|
logger.debug(f"Base directory for database: {cwd}")
|
||||||
|
|
||||||
# Define the .ra-aid directory path
|
# Define the .ra-aid directory path
|
||||||
ra_aid_dir_str = os.path.join(cwd, ".ra-aid")
|
ra_aid_dir_str = os.path.join(cwd, ".ra-aid")
|
||||||
|
|
@ -300,13 +324,17 @@ def init_db(in_memory: bool = False) -> peewee.SqliteDatabase:
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
def get_db() -> peewee.SqliteDatabase:
|
def get_db(base_dir: Optional[str] = None) -> peewee.SqliteDatabase:
|
||||||
"""
|
"""
|
||||||
Get the current database connection.
|
Get the current database connection.
|
||||||
|
|
||||||
If no connection exists, initializes a new one.
|
If no connection exists, initializes a new one.
|
||||||
If connection exists but is closed, reopens it.
|
If connection exists but is closed, reopens it.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
base_dir: Optional base directory to use instead of current working directory.
|
||||||
|
If None, uses os.getcwd() (default: None)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
peewee.SqliteDatabase: The current database connection
|
peewee.SqliteDatabase: The current database connection
|
||||||
"""
|
"""
|
||||||
|
|
@ -315,7 +343,7 @@ def get_db() -> peewee.SqliteDatabase:
|
||||||
if db is None:
|
if db is None:
|
||||||
# No database connection exists, initialize one
|
# No database connection exists, initialize one
|
||||||
# Use the default in-memory mode (False)
|
# Use the default in-memory mode (False)
|
||||||
return init_db(in_memory=False)
|
return init_db(in_memory=False, base_dir=base_dir)
|
||||||
|
|
||||||
# Check if connection is closed and reopen if needed
|
# Check if connection is closed and reopen if needed
|
||||||
if db.is_closed():
|
if db.is_closed():
|
||||||
|
|
@ -332,7 +360,7 @@ def get_db() -> peewee.SqliteDatabase:
|
||||||
in_memory = hasattr(db, "_is_in_memory") and db._is_in_memory
|
in_memory = hasattr(db, "_is_in_memory") and db._is_in_memory
|
||||||
logger.debug(f"Creating new database connection (in_memory={in_memory})")
|
logger.debug(f"Creating new database connection (in_memory={in_memory})")
|
||||||
# Create a completely new database object, don't reuse the old one
|
# Create a completely new database object, don't reuse the old one
|
||||||
return init_db(in_memory=in_memory)
|
return init_db(in_memory=in_memory, base_dir=base_dir)
|
||||||
|
|
||||||
return db
|
return db
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -15,6 +15,21 @@ from ra_aid.logging_config import get_logger
|
||||||
T = TypeVar("T", bound="BaseModel")
|
T = TypeVar("T", bound="BaseModel")
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
# Create a database proxy that will be initialized later
|
||||||
|
database_proxy = peewee.DatabaseProxy()
|
||||||
|
|
||||||
|
|
||||||
|
def initialize_database():
|
||||||
|
"""
|
||||||
|
Initialize the database proxy with a real database connection.
|
||||||
|
|
||||||
|
This function should be called before any database operations
|
||||||
|
to ensure the proxy points to a real database connection.
|
||||||
|
"""
|
||||||
|
db = get_db()
|
||||||
|
database_proxy.initialize(db)
|
||||||
|
return db
|
||||||
|
|
||||||
|
|
||||||
class BaseModel(peewee.Model):
|
class BaseModel(peewee.Model):
|
||||||
"""
|
"""
|
||||||
|
|
@ -28,7 +43,7 @@ class BaseModel(peewee.Model):
|
||||||
updated_at = peewee.DateTimeField(default=datetime.datetime.now)
|
updated_at = peewee.DateTimeField(default=datetime.datetime.now)
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
database = get_db()
|
database = database_proxy
|
||||||
|
|
||||||
def save(self, *args: Any, **kwargs: Any) -> int:
|
def save(self, *args: Any, **kwargs: Any) -> int:
|
||||||
"""
|
"""
|
||||||
|
|
@ -76,3 +91,21 @@ class KeyFact(BaseModel):
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
table_name = "key_fact"
|
table_name = "key_fact"
|
||||||
|
|
||||||
|
|
||||||
|
class KeySnippet(BaseModel):
|
||||||
|
"""
|
||||||
|
Model representing a key code snippet stored in the database.
|
||||||
|
|
||||||
|
Key snippets are important code fragments from the project that need to be
|
||||||
|
referenced later. Each snippet includes its file location, line number,
|
||||||
|
the code content itself, and an optional description of its significance.
|
||||||
|
"""
|
||||||
|
filepath = peewee.TextField()
|
||||||
|
line_number = peewee.IntegerField()
|
||||||
|
snippet = peewee.TextField()
|
||||||
|
description = peewee.TextField(null=True)
|
||||||
|
# created_at and updated_at are inherited from BaseModel
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
table_name = "key_snippet"
|
||||||
|
|
@ -10,7 +10,7 @@ from typing import Dict, List, Optional
|
||||||
import peewee
|
import peewee
|
||||||
|
|
||||||
from ra_aid.database.connection import get_db
|
from ra_aid.database.connection import get_db
|
||||||
from ra_aid.database.models import KeyFact
|
from ra_aid.database.models import KeyFact, initialize_database
|
||||||
from ra_aid.logging_config import get_logger
|
from ra_aid.logging_config import get_logger
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
@ -29,6 +29,15 @@ class KeyFactRepository:
|
||||||
all_facts = repo.get_all()
|
all_facts = repo.get_all()
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
def __init__(self, db=None):
|
||||||
|
"""
|
||||||
|
Initialize the repository with an optional database connection.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db: Optional database connection to use. If None, will use initialize_database()
|
||||||
|
"""
|
||||||
|
self.db = db
|
||||||
|
|
||||||
def create(self, content: str) -> KeyFact:
|
def create(self, content: str) -> KeyFact:
|
||||||
"""
|
"""
|
||||||
Create a new key fact in the database.
|
Create a new key fact in the database.
|
||||||
|
|
@ -43,7 +52,7 @@ class KeyFactRepository:
|
||||||
peewee.DatabaseError: If there's an error creating the fact
|
peewee.DatabaseError: If there's an error creating the fact
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
db = get_db()
|
db = self.db if self.db is not None else initialize_database()
|
||||||
fact = KeyFact.create(content=content)
|
fact = KeyFact.create(content=content)
|
||||||
logger.debug(f"Created key fact ID {fact.id}: {content}")
|
logger.debug(f"Created key fact ID {fact.id}: {content}")
|
||||||
return fact
|
return fact
|
||||||
|
|
@ -65,7 +74,7 @@ class KeyFactRepository:
|
||||||
peewee.DatabaseError: If there's an error accessing the database
|
peewee.DatabaseError: If there's an error accessing the database
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
db = get_db()
|
db = self.db if self.db is not None else initialize_database()
|
||||||
return KeyFact.get_or_none(KeyFact.id == fact_id)
|
return KeyFact.get_or_none(KeyFact.id == fact_id)
|
||||||
except peewee.DatabaseError as e:
|
except peewee.DatabaseError as e:
|
||||||
logger.error(f"Failed to fetch key fact {fact_id}: {str(e)}")
|
logger.error(f"Failed to fetch key fact {fact_id}: {str(e)}")
|
||||||
|
|
@ -86,7 +95,7 @@ class KeyFactRepository:
|
||||||
peewee.DatabaseError: If there's an error updating the fact
|
peewee.DatabaseError: If there's an error updating the fact
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
db = get_db()
|
db = self.db if self.db is not None else initialize_database()
|
||||||
# First check if the fact exists
|
# First check if the fact exists
|
||||||
fact = self.get(fact_id)
|
fact = self.get(fact_id)
|
||||||
if not fact:
|
if not fact:
|
||||||
|
|
@ -116,7 +125,7 @@ class KeyFactRepository:
|
||||||
peewee.DatabaseError: If there's an error deleting the fact
|
peewee.DatabaseError: If there's an error deleting the fact
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
db = get_db()
|
db = self.db if self.db is not None else initialize_database()
|
||||||
# First check if the fact exists
|
# First check if the fact exists
|
||||||
fact = self.get(fact_id)
|
fact = self.get(fact_id)
|
||||||
if not fact:
|
if not fact:
|
||||||
|
|
@ -142,7 +151,7 @@ class KeyFactRepository:
|
||||||
peewee.DatabaseError: If there's an error accessing the database
|
peewee.DatabaseError: If there's an error accessing the database
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
db = get_db()
|
db = self.db if self.db is not None else initialize_database()
|
||||||
return list(KeyFact.select().order_by(KeyFact.id))
|
return list(KeyFact.select().order_by(KeyFact.id))
|
||||||
except peewee.DatabaseError as e:
|
except peewee.DatabaseError as e:
|
||||||
logger.error(f"Failed to fetch all key facts: {str(e)}")
|
logger.error(f"Failed to fetch all key facts: {str(e)}")
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,214 @@
|
||||||
|
"""
|
||||||
|
Key snippet repository implementation for database access.
|
||||||
|
|
||||||
|
This module provides a repository implementation for the KeySnippet model,
|
||||||
|
following the repository pattern for data access abstraction.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Dict, List, Optional, Any
|
||||||
|
|
||||||
|
import peewee
|
||||||
|
|
||||||
|
from ra_aid.database.connection import get_db
|
||||||
|
from ra_aid.database.models import KeySnippet, initialize_database
|
||||||
|
from ra_aid.logging_config import get_logger
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class KeySnippetRepository:
|
||||||
|
"""
|
||||||
|
Repository for managing KeySnippet database operations.
|
||||||
|
|
||||||
|
This class provides methods for performing CRUD operations on the KeySnippet model,
|
||||||
|
abstracting the database access details from the business logic.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
repo = KeySnippetRepository()
|
||||||
|
snippet = repo.create(
|
||||||
|
filepath="main.py",
|
||||||
|
line_number=42,
|
||||||
|
snippet="def hello_world():",
|
||||||
|
description="Main function definition"
|
||||||
|
)
|
||||||
|
all_snippets = repo.get_all()
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, db=None):
|
||||||
|
"""
|
||||||
|
Initialize the repository with an optional database connection.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db: Optional database connection to use. If None, will use initialize_database()
|
||||||
|
"""
|
||||||
|
self.db = db
|
||||||
|
|
||||||
|
def create(
|
||||||
|
self, filepath: str, line_number: int, snippet: str, description: Optional[str] = None
|
||||||
|
) -> KeySnippet:
|
||||||
|
"""
|
||||||
|
Create a new key snippet in the database.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
filepath: Path to the source file
|
||||||
|
line_number: Line number where the snippet starts
|
||||||
|
snippet: The source code snippet text
|
||||||
|
description: Optional description of the significance
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
KeySnippet: The newly created key snippet instance
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
peewee.DatabaseError: If there's an error creating the snippet
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
db = self.db if self.db is not None else initialize_database()
|
||||||
|
key_snippet = KeySnippet.create(
|
||||||
|
filepath=filepath,
|
||||||
|
line_number=line_number,
|
||||||
|
snippet=snippet,
|
||||||
|
description=description
|
||||||
|
)
|
||||||
|
logger.debug(f"Created key snippet ID {key_snippet.id}: {filepath}:{line_number}")
|
||||||
|
return key_snippet
|
||||||
|
except peewee.DatabaseError as e:
|
||||||
|
logger.error(f"Failed to create key snippet: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def get(self, snippet_id: int) -> Optional[KeySnippet]:
|
||||||
|
"""
|
||||||
|
Retrieve a key snippet by its ID.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
snippet_id: The ID of the key snippet to retrieve
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Optional[KeySnippet]: The key snippet instance if found, None otherwise
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
peewee.DatabaseError: If there's an error accessing the database
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
db = self.db if self.db is not None else initialize_database()
|
||||||
|
return KeySnippet.get_or_none(KeySnippet.id == snippet_id)
|
||||||
|
except peewee.DatabaseError as e:
|
||||||
|
logger.error(f"Failed to fetch key snippet {snippet_id}: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def update(
|
||||||
|
self,
|
||||||
|
snippet_id: int,
|
||||||
|
filepath: str,
|
||||||
|
line_number: int,
|
||||||
|
snippet: str,
|
||||||
|
description: Optional[str] = None
|
||||||
|
) -> Optional[KeySnippet]:
|
||||||
|
"""
|
||||||
|
Update an existing key snippet.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
snippet_id: The ID of the key snippet to update
|
||||||
|
filepath: Path to the source file
|
||||||
|
line_number: Line number where the snippet starts
|
||||||
|
snippet: The source code snippet text
|
||||||
|
description: Optional description of the significance
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Optional[KeySnippet]: The updated key snippet if found, None otherwise
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
peewee.DatabaseError: If there's an error updating the snippet
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
db = self.db if self.db is not None else initialize_database()
|
||||||
|
# First check if the snippet exists
|
||||||
|
key_snippet = self.get(snippet_id)
|
||||||
|
if not key_snippet:
|
||||||
|
logger.warning(f"Attempted to update non-existent key snippet {snippet_id}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Update the snippet
|
||||||
|
key_snippet.filepath = filepath
|
||||||
|
key_snippet.line_number = line_number
|
||||||
|
key_snippet.snippet = snippet
|
||||||
|
key_snippet.description = description
|
||||||
|
key_snippet.save()
|
||||||
|
logger.debug(f"Updated key snippet ID {snippet_id}: {filepath}:{line_number}")
|
||||||
|
return key_snippet
|
||||||
|
except peewee.DatabaseError as e:
|
||||||
|
logger.error(f"Failed to update key snippet {snippet_id}: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def delete(self, snippet_id: int) -> bool:
|
||||||
|
"""
|
||||||
|
Delete a key snippet by its ID.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
snippet_id: The ID of the key snippet to delete
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if the snippet was deleted, False if it wasn't found
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
peewee.DatabaseError: If there's an error deleting the snippet
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
db = self.db if self.db is not None else initialize_database()
|
||||||
|
# First check if the snippet exists
|
||||||
|
key_snippet = self.get(snippet_id)
|
||||||
|
if not key_snippet:
|
||||||
|
logger.warning(f"Attempted to delete non-existent key snippet {snippet_id}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Delete the snippet
|
||||||
|
key_snippet.delete_instance()
|
||||||
|
logger.debug(f"Deleted key snippet ID {snippet_id}")
|
||||||
|
return True
|
||||||
|
except peewee.DatabaseError as e:
|
||||||
|
logger.error(f"Failed to delete key snippet {snippet_id}: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def get_all(self) -> List[KeySnippet]:
|
||||||
|
"""
|
||||||
|
Retrieve all key snippets from the database.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[KeySnippet]: List of all key snippet instances
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
peewee.DatabaseError: If there's an error accessing the database
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
db = self.db if self.db is not None else initialize_database()
|
||||||
|
return list(KeySnippet.select().order_by(KeySnippet.id))
|
||||||
|
except peewee.DatabaseError as e:
|
||||||
|
logger.error(f"Failed to fetch all key snippets: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def get_snippets_dict(self) -> Dict[int, Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Retrieve all key snippets as a dictionary mapping IDs to snippet information.
|
||||||
|
|
||||||
|
This method is useful for compatibility with the existing memory format.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict[int, Dict[str, Any]]: Dictionary with snippet IDs as keys and
|
||||||
|
snippet information as values
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
peewee.DatabaseError: If there's an error accessing the database
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
snippets = self.get_all()
|
||||||
|
return {
|
||||||
|
snippet.id: {
|
||||||
|
"filepath": snippet.filepath,
|
||||||
|
"line_number": snippet.line_number,
|
||||||
|
"snippet": snippet.snippet,
|
||||||
|
"description": snippet.description
|
||||||
|
}
|
||||||
|
for snippet in snippets
|
||||||
|
}
|
||||||
|
except peewee.DatabaseError as e:
|
||||||
|
logger.error(f"Failed to fetch key snippets as dictionary: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
@ -1,5 +1,8 @@
|
||||||
"""
|
"""
|
||||||
Tests for the database connection module.
|
Tests for the database connection module.
|
||||||
|
|
||||||
|
This file tests the database connection functionality using pytest's fixtures
|
||||||
|
for proper test isolation.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
|
@ -21,7 +24,10 @@ from ra_aid.database.connection import (
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def cleanup_db():
|
def cleanup_db():
|
||||||
"""
|
"""
|
||||||
Fixture to clean up database connections after tests.
|
Fixture to clean up database connections between tests.
|
||||||
|
|
||||||
|
This ensures that we don't leak database connections between tests
|
||||||
|
and that the db_var contextvar is reset.
|
||||||
"""
|
"""
|
||||||
# Run the test
|
# Run the test
|
||||||
yield
|
yield
|
||||||
|
|
@ -29,10 +35,13 @@ def cleanup_db():
|
||||||
# Clean up after the test
|
# Clean up after the test
|
||||||
db = db_var.get()
|
db = db_var.get()
|
||||||
if db is not None:
|
if db is not None:
|
||||||
|
# Clean up attributes we may have added
|
||||||
if hasattr(db, "_is_in_memory"):
|
if hasattr(db, "_is_in_memory"):
|
||||||
delattr(db, "_is_in_memory")
|
delattr(db, "_is_in_memory")
|
||||||
if hasattr(db, "_message_shown"):
|
if hasattr(db, "_message_shown"):
|
||||||
delattr(db, "_message_shown")
|
delattr(db, "_message_shown")
|
||||||
|
|
||||||
|
# Close the connection if it's open
|
||||||
if not db.is_closed():
|
if not db.is_closed():
|
||||||
db.close()
|
db.close()
|
||||||
|
|
||||||
|
|
@ -41,348 +50,313 @@ def cleanup_db():
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def setup_in_memory_db():
|
def db_path_mock(tmp_path, monkeypatch):
|
||||||
"""
|
"""
|
||||||
Fixture to set up an in-memory database for testing.
|
Fixture to mock os.getcwd() to return a temporary directory path.
|
||||||
|
|
||||||
|
This ensures that all database operations use the temporary directory
|
||||||
|
and never touch the actual current working directory.
|
||||||
"""
|
"""
|
||||||
# Initialize in-memory database
|
|
||||||
db = init_db(in_memory=True)
|
|
||||||
|
|
||||||
# Run the test
|
|
||||||
yield db
|
|
||||||
|
|
||||||
# Clean up
|
|
||||||
if not db.is_closed():
|
|
||||||
db.close()
|
|
||||||
db_var.set(None)
|
|
||||||
|
|
||||||
|
|
||||||
def test_init_db_creates_directory(cleanup_db, tmp_path):
|
|
||||||
"""
|
|
||||||
Test that init_db creates the .ra-aid directory if it doesn't exist.
|
|
||||||
"""
|
|
||||||
# Get and print the original working directory
|
|
||||||
original_cwd = os.getcwd()
|
original_cwd = os.getcwd()
|
||||||
print(f"Original working directory: {original_cwd}")
|
|
||||||
|
|
||||||
# Convert tmp_path to string for consistent handling
|
|
||||||
tmp_path_str = str(tmp_path.absolute())
|
tmp_path_str = str(tmp_path.absolute())
|
||||||
print(f"Temporary directory path: {tmp_path_str}")
|
|
||||||
|
|
||||||
# Change to the temporary directory
|
# Create the .ra-aid directory in the temporary path
|
||||||
os.chdir(tmp_path_str)
|
ra_aid_dir = tmp_path / ".ra-aid"
|
||||||
current_cwd = os.getcwd()
|
ra_aid_dir.mkdir(exist_ok=True)
|
||||||
print(f"Changed working directory to: {current_cwd}")
|
|
||||||
assert (
|
|
||||||
current_cwd == tmp_path_str
|
|
||||||
), f"Failed to change directory: {current_cwd} != {tmp_path_str}"
|
|
||||||
|
|
||||||
# Create the .ra-aid directory manually to ensure it exists
|
# Mock os.getcwd() to return the temporary directory
|
||||||
ra_aid_path_str = os.path.join(current_cwd, ".ra-aid")
|
monkeypatch.setattr(os, "getcwd", lambda: tmp_path_str)
|
||||||
print(f"Creating .ra-aid directory at: {ra_aid_path_str}")
|
|
||||||
os.makedirs(ra_aid_path_str, exist_ok=True)
|
|
||||||
|
|
||||||
# Verify the directory was created
|
yield tmp_path
|
||||||
assert os.path.exists(
|
|
||||||
ra_aid_path_str
|
|
||||||
), f".ra-aid directory not found at {ra_aid_path_str}"
|
|
||||||
assert os.path.isdir(
|
|
||||||
ra_aid_path_str
|
|
||||||
), f"{ra_aid_path_str} exists but is not a directory"
|
|
||||||
|
|
||||||
# Create a test file to verify write permissions
|
# Ensure we're back to the original directory after the test
|
||||||
test_file_path = os.path.join(ra_aid_path_str, "test_write.txt")
|
os.chdir(original_cwd)
|
||||||
print(f"Creating test file to verify write permissions: {test_file_path}")
|
|
||||||
with open(test_file_path, "w") as f:
|
|
||||||
f.write("Test write permissions")
|
|
||||||
|
|
||||||
# Verify the test file was created
|
|
||||||
assert os.path.exists(test_file_path), f"Test file not created at {test_file_path}"
|
|
||||||
|
|
||||||
# Create an empty database file to ensure it exists before init_db
|
class TestInitDb:
|
||||||
db_file_str = os.path.join(ra_aid_path_str, "pk.db")
|
"""Tests for the init_db function."""
|
||||||
print(f"Creating empty database file at: {db_file_str}")
|
|
||||||
with open(db_file_str, "w") as f:
|
|
||||||
f.write("") # Create empty file
|
|
||||||
|
|
||||||
# Verify the database file was created
|
def test_init_db_in_memory(self, cleanup_db):
|
||||||
assert os.path.exists(
|
"""Test init_db with in_memory=True."""
|
||||||
db_file_str
|
# Reset the contextvar to ensure a fresh start
|
||||||
), f"Empty database file not created at {db_file_str}"
|
db_var.set(None)
|
||||||
print(f"Empty database file size: {os.path.getsize(db_file_str)} bytes")
|
db = init_db(in_memory=True)
|
||||||
|
|
||||||
# Get directory permissions for debugging
|
|
||||||
dir_perms = oct(os.stat(ra_aid_path_str).st_mode)[-3:]
|
|
||||||
print(f"Directory permissions: {dir_perms}")
|
|
||||||
|
|
||||||
# Initialize the database
|
|
||||||
print("Calling init_db()")
|
|
||||||
db = init_db()
|
|
||||||
print("init_db() returned successfully")
|
|
||||||
|
|
||||||
# List contents of the current directory for debugging
|
|
||||||
print(f"Contents of current directory: {os.listdir(current_cwd)}")
|
|
||||||
|
|
||||||
# List contents of the .ra-aid directory for debugging
|
|
||||||
print(f"Contents of .ra-aid directory: {os.listdir(ra_aid_path_str)}")
|
|
||||||
|
|
||||||
# Check that the database file exists using os.path
|
|
||||||
assert os.path.exists(db_file_str), f"Database file not found at {db_file_str}"
|
|
||||||
assert os.path.isfile(db_file_str), f"{db_file_str} exists but is not a file"
|
|
||||||
print(f"Final database file size: {os.path.getsize(db_file_str)} bytes")
|
|
||||||
|
|
||||||
|
|
||||||
def test_init_db_creates_database_file(cleanup_db, tmp_path):
|
|
||||||
"""
|
|
||||||
Test that init_db creates the database file.
|
|
||||||
"""
|
|
||||||
# Change to the temporary directory
|
|
||||||
os.chdir(tmp_path)
|
|
||||||
|
|
||||||
# Initialize the database
|
|
||||||
init_db()
|
|
||||||
|
|
||||||
# Check that the database file was created
|
|
||||||
assert (tmp_path / ".ra-aid" / "pk.db").exists()
|
|
||||||
assert (tmp_path / ".ra-aid" / "pk.db").is_file()
|
|
||||||
|
|
||||||
|
|
||||||
def test_init_db_returns_database_connection(cleanup_db):
|
|
||||||
"""
|
|
||||||
Test that init_db returns a database connection.
|
|
||||||
"""
|
|
||||||
# Initialize the database
|
|
||||||
db = init_db()
|
|
||||||
|
|
||||||
# Check that the database connection is returned
|
|
||||||
assert isinstance(db, peewee.SqliteDatabase)
|
|
||||||
assert not db.is_closed()
|
|
||||||
|
|
||||||
|
|
||||||
def test_init_db_with_in_memory_mode(cleanup_db):
|
|
||||||
"""
|
|
||||||
Test that init_db with in_memory=True creates an in-memory database.
|
|
||||||
"""
|
|
||||||
# Initialize the database in in-memory mode
|
|
||||||
db = init_db(in_memory=True)
|
|
||||||
|
|
||||||
# Check that the database connection is returned
|
|
||||||
assert isinstance(db, peewee.SqliteDatabase)
|
|
||||||
assert not db.is_closed()
|
|
||||||
assert hasattr(db, "_is_in_memory")
|
|
||||||
assert db._is_in_memory is True
|
|
||||||
|
|
||||||
|
|
||||||
def test_in_memory_mode_no_directory_created(cleanup_db, tmp_path):
|
|
||||||
"""
|
|
||||||
Test that when using in-memory mode, no directory is created.
|
|
||||||
"""
|
|
||||||
# Change to the temporary directory
|
|
||||||
os.chdir(tmp_path)
|
|
||||||
|
|
||||||
# Initialize the database in in-memory mode
|
|
||||||
init_db(in_memory=True)
|
|
||||||
|
|
||||||
# Check that the .ra-aid directory was not created
|
|
||||||
# (Note: it might be created by other tests, so we can't assert it doesn't exist)
|
|
||||||
# Instead, check that the database file was not created
|
|
||||||
assert not (tmp_path / ".ra-aid" / "pk.db").exists()
|
|
||||||
|
|
||||||
|
|
||||||
def test_init_db_returns_existing_connection(cleanup_db):
|
|
||||||
"""
|
|
||||||
Test that init_db returns the existing connection if one exists.
|
|
||||||
"""
|
|
||||||
# Initialize the database
|
|
||||||
db1 = init_db()
|
|
||||||
|
|
||||||
# Initialize the database again
|
|
||||||
db2 = init_db()
|
|
||||||
|
|
||||||
# Check that the same connection is returned
|
|
||||||
assert db1 is db2
|
|
||||||
|
|
||||||
|
|
||||||
def test_init_db_reopens_closed_connection(cleanup_db):
|
|
||||||
"""
|
|
||||||
Test that init_db reopens a closed connection.
|
|
||||||
"""
|
|
||||||
# Initialize the database
|
|
||||||
db1 = init_db()
|
|
||||||
|
|
||||||
# Close the connection
|
|
||||||
db1.close()
|
|
||||||
|
|
||||||
# Initialize the database again
|
|
||||||
db2 = init_db()
|
|
||||||
|
|
||||||
# Check that the same connection is returned and it's open
|
|
||||||
assert db1 is db2
|
|
||||||
assert not db1.is_closed()
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_db_initializes_connection(cleanup_db):
|
|
||||||
"""
|
|
||||||
Test that get_db initializes a connection if none exists.
|
|
||||||
"""
|
|
||||||
# Get the database connection
|
|
||||||
db = get_db()
|
|
||||||
|
|
||||||
# Check that a connection was initialized
|
|
||||||
assert isinstance(db, peewee.SqliteDatabase)
|
|
||||||
assert not db.is_closed()
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_db_returns_existing_connection(cleanup_db):
|
|
||||||
"""
|
|
||||||
Test that get_db returns the existing connection if one exists.
|
|
||||||
"""
|
|
||||||
# Initialize the database
|
|
||||||
db1 = init_db()
|
|
||||||
|
|
||||||
# Get the database connection
|
|
||||||
db2 = get_db()
|
|
||||||
|
|
||||||
# Check that the same connection is returned
|
|
||||||
assert db1 is db2
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_db_reopens_closed_connection(cleanup_db):
|
|
||||||
"""
|
|
||||||
Test that get_db reopens a closed connection.
|
|
||||||
"""
|
|
||||||
# Initialize the database
|
|
||||||
db = init_db()
|
|
||||||
|
|
||||||
# Close the connection
|
|
||||||
db.close()
|
|
||||||
|
|
||||||
# Get the database connection
|
|
||||||
db2 = get_db()
|
|
||||||
|
|
||||||
# Check that the same connection is returned and it's open
|
|
||||||
assert db is db2
|
|
||||||
assert not db.is_closed()
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_db_handles_reopen_error(cleanup_db, monkeypatch):
|
|
||||||
"""
|
|
||||||
Test that get_db handles errors when reopening a connection.
|
|
||||||
"""
|
|
||||||
# Initialize the database
|
|
||||||
db = init_db()
|
|
||||||
|
|
||||||
# Close the connection
|
|
||||||
db.close()
|
|
||||||
|
|
||||||
# Create a patched version of the connect method that raises an error
|
|
||||||
original_connect = peewee.SqliteDatabase.connect
|
|
||||||
|
|
||||||
def mock_connect(self, *args, **kwargs):
|
|
||||||
if self is db: # Only raise for the specific db instance
|
|
||||||
raise peewee.OperationalError("Test error")
|
|
||||||
return original_connect(self, *args, **kwargs)
|
|
||||||
|
|
||||||
# Apply the patch
|
|
||||||
monkeypatch.setattr(peewee.SqliteDatabase, "connect", mock_connect)
|
|
||||||
|
|
||||||
# Get the database connection
|
|
||||||
db2 = get_db()
|
|
||||||
|
|
||||||
# Check that a new connection was initialized
|
|
||||||
assert db is not db2
|
|
||||||
assert not db2.is_closed()
|
|
||||||
|
|
||||||
|
|
||||||
def test_close_db_closes_connection(cleanup_db):
|
|
||||||
"""
|
|
||||||
Test that close_db closes the connection.
|
|
||||||
"""
|
|
||||||
# Initialize the database
|
|
||||||
db = init_db()
|
|
||||||
|
|
||||||
# Close the connection
|
|
||||||
close_db()
|
|
||||||
|
|
||||||
# Check that the connection is closed
|
|
||||||
assert db.is_closed()
|
|
||||||
|
|
||||||
|
|
||||||
def test_close_db_handles_no_connection():
|
|
||||||
"""
|
|
||||||
Test that close_db handles the case where no connection exists.
|
|
||||||
"""
|
|
||||||
# Reset the contextvar
|
|
||||||
db_var.set(None)
|
|
||||||
|
|
||||||
# Close the connection (should not raise an error)
|
|
||||||
close_db()
|
|
||||||
|
|
||||||
|
|
||||||
def test_close_db_handles_already_closed_connection(cleanup_db):
|
|
||||||
"""
|
|
||||||
Test that close_db handles the case where the connection is already closed.
|
|
||||||
"""
|
|
||||||
# Initialize the database
|
|
||||||
db = init_db()
|
|
||||||
|
|
||||||
# Close the connection
|
|
||||||
db.close()
|
|
||||||
|
|
||||||
# Close the connection again (should not raise an error)
|
|
||||||
close_db()
|
|
||||||
|
|
||||||
|
|
||||||
@patch("ra_aid.database.connection.peewee.SqliteDatabase.close")
|
|
||||||
def test_close_db_handles_error(mock_close, cleanup_db):
|
|
||||||
"""
|
|
||||||
Test that close_db handles errors when closing the connection.
|
|
||||||
"""
|
|
||||||
# Initialize the database
|
|
||||||
init_db()
|
|
||||||
|
|
||||||
# Make close raise an error
|
|
||||||
mock_close.side_effect = peewee.DatabaseError("Test error")
|
|
||||||
|
|
||||||
# Close the connection (should not raise an error)
|
|
||||||
close_db()
|
|
||||||
|
|
||||||
|
|
||||||
def test_database_manager_context_manager(cleanup_db):
|
|
||||||
"""
|
|
||||||
Test that DatabaseManager works as a context manager.
|
|
||||||
"""
|
|
||||||
# Use the context manager
|
|
||||||
with DatabaseManager() as db:
|
|
||||||
# Check that a connection was initialized
|
|
||||||
assert isinstance(db, peewee.SqliteDatabase)
|
|
||||||
assert not db.is_closed()
|
|
||||||
|
|
||||||
# Store the connection for later
|
|
||||||
db_in_context = db
|
|
||||||
|
|
||||||
# Check that the connection is closed after exiting the context
|
|
||||||
assert db_in_context.is_closed()
|
|
||||||
|
|
||||||
|
|
||||||
def test_database_manager_with_in_memory_mode(cleanup_db):
|
|
||||||
"""
|
|
||||||
Test that DatabaseManager with in_memory=True creates an in-memory database.
|
|
||||||
"""
|
|
||||||
# Use the context manager with in_memory=True
|
|
||||||
with DatabaseManager(in_memory=True) as db:
|
|
||||||
# Check that a connection was initialized
|
|
||||||
assert isinstance(db, peewee.SqliteDatabase)
|
assert isinstance(db, peewee.SqliteDatabase)
|
||||||
assert not db.is_closed()
|
assert not db.is_closed()
|
||||||
assert hasattr(db, "_is_in_memory")
|
assert hasattr(db, "_is_in_memory")
|
||||||
assert db._is_in_memory is True
|
assert db._is_in_memory is True
|
||||||
|
|
||||||
|
def test_init_db_creates_directory(self, cleanup_db, db_path_mock):
|
||||||
|
"""Test that init_db creates the .ra-aid directory if it doesn't exist."""
|
||||||
|
# Remove the .ra-aid directory to test creation
|
||||||
|
ra_aid_dir = db_path_mock / ".ra-aid"
|
||||||
|
if ra_aid_dir.exists():
|
||||||
|
for item in ra_aid_dir.iterdir():
|
||||||
|
if item.is_file():
|
||||||
|
item.unlink()
|
||||||
|
ra_aid_dir.rmdir()
|
||||||
|
|
||||||
|
# Initialize the database
|
||||||
|
db = init_db()
|
||||||
|
|
||||||
|
# Check that the directory was created
|
||||||
|
assert ra_aid_dir.exists()
|
||||||
|
assert ra_aid_dir.is_dir()
|
||||||
|
assert isinstance(db, peewee.SqliteDatabase)
|
||||||
|
assert not db.is_closed()
|
||||||
|
assert hasattr(db, "_is_in_memory")
|
||||||
|
assert db._is_in_memory is False
|
||||||
|
|
||||||
|
def test_init_db_creates_database_file(self, cleanup_db, db_path_mock):
|
||||||
|
"""Test that init_db creates the database file."""
|
||||||
|
# Initialize the database
|
||||||
|
init_db()
|
||||||
|
|
||||||
|
# Check that the database file was created
|
||||||
|
assert (db_path_mock / ".ra-aid" / "pk.db").exists()
|
||||||
|
assert (db_path_mock / ".ra-aid" / "pk.db").is_file()
|
||||||
|
|
||||||
|
def test_init_db_reuses_connection(self, cleanup_db):
|
||||||
|
"""Test that init_db reuses an existing connection."""
|
||||||
|
# Reset the contextvar to ensure a fresh start
|
||||||
|
db_var.set(None)
|
||||||
|
|
||||||
|
# Use in_memory=True for this test to avoid touching the filesystem
|
||||||
|
db1 = init_db(in_memory=True)
|
||||||
|
db2 = init_db(in_memory=True)
|
||||||
|
|
||||||
|
assert db1 is db2
|
||||||
|
|
||||||
|
def test_init_db_reopens_closed_connection(self, cleanup_db):
|
||||||
|
"""Test that init_db reopens a closed connection."""
|
||||||
|
# Reset the contextvar to ensure a fresh start
|
||||||
|
db_var.set(None)
|
||||||
|
|
||||||
|
# Use in_memory=True for this test to avoid touching the filesystem
|
||||||
|
db1 = init_db(in_memory=True)
|
||||||
|
db1.close()
|
||||||
|
assert db1.is_closed()
|
||||||
|
|
||||||
|
db2 = init_db(in_memory=True)
|
||||||
|
assert db1 is db2
|
||||||
|
assert not db1.is_closed()
|
||||||
|
|
||||||
|
def test_in_memory_mode_no_directory_created(self, cleanup_db, db_path_mock):
|
||||||
|
"""Test that when using in_memory mode, no database file is created."""
|
||||||
|
# Initialize the database in in-memory mode
|
||||||
|
init_db(in_memory=True)
|
||||||
|
|
||||||
|
# Check that the database file was not created
|
||||||
|
assert not (db_path_mock / ".ra-aid" / "pk.db").exists()
|
||||||
|
|
||||||
|
def test_init_db_sets_is_in_memory_attribute(self, cleanup_db):
|
||||||
|
"""Test that init_db sets the _is_in_memory attribute."""
|
||||||
|
# Test with in_memory=True
|
||||||
|
db = init_db(in_memory=True)
|
||||||
|
assert hasattr(db, "_is_in_memory")
|
||||||
|
assert db._is_in_memory is True
|
||||||
|
|
||||||
|
# Reset the contextvar
|
||||||
|
db_var.set(None)
|
||||||
|
|
||||||
|
# Test with in_memory=False, but use a mocked directory
|
||||||
|
with patch("os.getcwd") as mock_getcwd:
|
||||||
|
temp_dir = Path("/tmp/testdb")
|
||||||
|
mock_getcwd.return_value = str(temp_dir)
|
||||||
|
|
||||||
|
# Mock os.path.exists and os.makedirs to avoid filesystem operations
|
||||||
|
with patch("os.path.exists", return_value=True):
|
||||||
|
with patch("os.makedirs"):
|
||||||
|
with patch("os.path.isdir", return_value=True):
|
||||||
|
with patch.object(peewee.SqliteDatabase, "connect"):
|
||||||
|
with patch.object(peewee.SqliteDatabase, "execute_sql"):
|
||||||
|
db = init_db(in_memory=False)
|
||||||
|
assert hasattr(db, "_is_in_memory")
|
||||||
|
assert db._is_in_memory is False
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetDb:
|
||||||
|
"""Tests for the get_db function."""
|
||||||
|
|
||||||
|
def test_get_db_initializes_connection(self, cleanup_db):
|
||||||
|
"""Test that get_db initializes a connection if none exists."""
|
||||||
|
# Reset the contextvar to ensure no connection exists
|
||||||
|
db_var.set(None)
|
||||||
|
|
||||||
|
# Use a patch to avoid touching the filesystem
|
||||||
|
with patch("ra_aid.database.connection.init_db") as mock_init_db:
|
||||||
|
mock_db = peewee.SqliteDatabase(":memory:")
|
||||||
|
mock_db._is_in_memory = False
|
||||||
|
mock_init_db.return_value = mock_db
|
||||||
|
|
||||||
|
db = get_db()
|
||||||
|
|
||||||
|
mock_init_db.assert_called_once_with(in_memory=False, base_dir=None)
|
||||||
|
assert db is mock_db
|
||||||
|
|
||||||
|
def test_get_db_returns_existing_connection(self, cleanup_db):
|
||||||
|
"""Test that get_db returns the existing connection if one exists."""
|
||||||
|
# Reset the contextvar to ensure a fresh start
|
||||||
|
db_var.set(None)
|
||||||
|
|
||||||
|
# Use in_memory=True for this test to avoid touching the filesystem
|
||||||
|
db1 = init_db(in_memory=True)
|
||||||
|
db2 = get_db()
|
||||||
|
|
||||||
|
assert db1 is db2
|
||||||
|
|
||||||
|
def test_get_db_reopens_closed_connection(self, cleanup_db):
|
||||||
|
"""Test that get_db reopens a closed connection."""
|
||||||
|
# Reset the contextvar to ensure a fresh start
|
||||||
|
db_var.set(None)
|
||||||
|
|
||||||
|
# Use in_memory=True for this test to avoid touching the filesystem
|
||||||
|
db1 = init_db(in_memory=True)
|
||||||
|
db1.close()
|
||||||
|
assert db1.is_closed()
|
||||||
|
|
||||||
|
db2 = get_db()
|
||||||
|
assert db1 is db2
|
||||||
|
assert not db1.is_closed()
|
||||||
|
|
||||||
|
def test_get_db_handles_reopen_error(self, cleanup_db, monkeypatch):
|
||||||
|
"""Test that get_db handles errors when reopening a connection."""
|
||||||
|
# Reset the contextvar to ensure a fresh start
|
||||||
|
db_var.set(None)
|
||||||
|
|
||||||
|
# Use in_memory=True for this test to avoid touching the filesystem
|
||||||
|
db = init_db(in_memory=True)
|
||||||
|
|
||||||
|
# Close the connection
|
||||||
|
db.close()
|
||||||
|
|
||||||
|
# Create a patched version of the connect method that raises an error
|
||||||
|
original_connect = peewee.SqliteDatabase.connect
|
||||||
|
|
||||||
|
def mock_connect(self, *args, **kwargs):
|
||||||
|
if self is db: # Only raise for the specific db instance
|
||||||
|
raise peewee.OperationalError("Test error")
|
||||||
|
return original_connect(self, *args, **kwargs)
|
||||||
|
|
||||||
|
# Apply the patch
|
||||||
|
monkeypatch.setattr(peewee.SqliteDatabase, "connect", mock_connect)
|
||||||
|
|
||||||
|
# Get the database connection - this should create a new one
|
||||||
|
db2 = get_db()
|
||||||
|
|
||||||
|
# Check that a new connection was initialized
|
||||||
|
assert db is not db2
|
||||||
|
assert not db2.is_closed()
|
||||||
|
assert hasattr(db2, "_is_in_memory")
|
||||||
|
assert db2._is_in_memory is True # Should preserve the in_memory setting
|
||||||
|
|
||||||
|
|
||||||
|
class TestCloseDb:
|
||||||
|
"""Tests for the close_db function."""
|
||||||
|
|
||||||
|
def test_close_db_closes_connection(self, cleanup_db):
|
||||||
|
"""Test that close_db closes the connection."""
|
||||||
|
# Use in_memory=True for this test to avoid touching the filesystem
|
||||||
|
db = init_db(in_memory=True)
|
||||||
|
|
||||||
|
# Close the connection
|
||||||
|
close_db()
|
||||||
|
|
||||||
|
# Check that the connection is closed
|
||||||
|
assert db.is_closed()
|
||||||
|
|
||||||
|
def test_close_db_handles_no_connection(self):
|
||||||
|
"""Test that close_db handles the case where no connection exists."""
|
||||||
|
# Reset the contextvar
|
||||||
|
db_var.set(None)
|
||||||
|
|
||||||
|
# Close the connection (should not raise an error)
|
||||||
|
close_db()
|
||||||
|
|
||||||
|
def test_close_db_handles_already_closed_connection(self, cleanup_db):
|
||||||
|
"""Test that close_db handles the case where the connection is already closed."""
|
||||||
|
# Use in_memory=True for this test to avoid touching the filesystem
|
||||||
|
db = init_db(in_memory=True)
|
||||||
|
|
||||||
|
# Close the connection
|
||||||
|
db.close()
|
||||||
|
|
||||||
|
# Close the connection again (should not raise an error)
|
||||||
|
close_db()
|
||||||
|
|
||||||
|
@patch("ra_aid.database.connection.peewee.SqliteDatabase.close")
|
||||||
|
def test_close_db_handles_error(self, mock_close, cleanup_db):
|
||||||
|
"""Test that close_db handles errors when closing the connection."""
|
||||||
|
# Use in_memory=True for this test to avoid touching the filesystem
|
||||||
|
init_db(in_memory=True)
|
||||||
|
|
||||||
|
# Make close raise an error
|
||||||
|
mock_close.side_effect = peewee.DatabaseError("Test error")
|
||||||
|
|
||||||
|
# Close the connection (should not raise an error)
|
||||||
|
close_db()
|
||||||
|
|
||||||
|
|
||||||
|
class TestDatabaseManager:
|
||||||
|
"""Tests for the DatabaseManager class."""
|
||||||
|
|
||||||
|
def test_database_manager_context_manager_in_memory(self, cleanup_db):
|
||||||
|
"""Test that DatabaseManager works as a context manager with in_memory=True."""
|
||||||
|
# Use in_memory=True for this test to avoid touching the filesystem
|
||||||
|
with DatabaseManager(in_memory=True) as db:
|
||||||
|
# Check that a connection was initialized
|
||||||
|
assert isinstance(db, peewee.SqliteDatabase)
|
||||||
|
assert not db.is_closed()
|
||||||
|
assert hasattr(db, "_is_in_memory")
|
||||||
|
assert db._is_in_memory is True
|
||||||
|
|
||||||
|
# Store the connection for later
|
||||||
|
db_in_context = db
|
||||||
|
|
||||||
|
# Check that the connection is closed after exiting the context
|
||||||
|
assert db_in_context.is_closed()
|
||||||
|
|
||||||
|
def test_database_manager_context_manager_physical_file(self, cleanup_db, db_path_mock):
|
||||||
|
"""Test that DatabaseManager works as a context manager with a physical file."""
|
||||||
|
with DatabaseManager(in_memory=False) as db:
|
||||||
|
# Check that a connection was initialized
|
||||||
|
assert isinstance(db, peewee.SqliteDatabase)
|
||||||
|
assert not db.is_closed()
|
||||||
|
assert hasattr(db, "_is_in_memory")
|
||||||
|
assert db._is_in_memory is False
|
||||||
|
|
||||||
|
# Check that the database file was created
|
||||||
|
assert (db_path_mock / ".ra-aid" / "pk.db").exists()
|
||||||
|
assert (db_path_mock / ".ra-aid" / "pk.db").is_file()
|
||||||
|
|
||||||
|
# Store the connection for later
|
||||||
|
db_in_context = db
|
||||||
|
|
||||||
|
# Check that the connection is closed after exiting the context
|
||||||
|
assert db_in_context.is_closed()
|
||||||
|
|
||||||
|
def test_database_manager_exception_handling(self, cleanup_db):
|
||||||
|
"""Test that DatabaseManager properly handles exceptions."""
|
||||||
|
# Use in_memory=True for this test to avoid touching the filesystem
|
||||||
|
try:
|
||||||
|
with DatabaseManager(in_memory=True) as db:
|
||||||
|
assert not db.is_closed()
|
||||||
|
raise ValueError("Test exception")
|
||||||
|
except ValueError:
|
||||||
|
# The exception should be propagated
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Verify the connection is closed even if an exception occurred
|
||||||
|
assert db.is_closed()
|
||||||
|
|
||||||
|
|
||||||
def test_init_db_shows_message_only_once(cleanup_db, caplog):
|
def test_init_db_shows_message_only_once(cleanup_db, caplog):
|
||||||
"""
|
"""Test that init_db only shows the initialization message once."""
|
||||||
Test that init_db only shows the initialization message once.
|
# Reset the contextvar to ensure a fresh start
|
||||||
"""
|
db_var.set(None)
|
||||||
# Initialize the database
|
|
||||||
|
# Use in_memory=True for this test to avoid touching the filesystem
|
||||||
init_db(in_memory=True)
|
init_db(in_memory=True)
|
||||||
|
|
||||||
# Clear the log
|
# Clear the log
|
||||||
|
|
@ -393,270 +367,3 @@ def test_init_db_shows_message_only_once(cleanup_db, caplog):
|
||||||
|
|
||||||
# Check that no message was logged
|
# Check that no message was logged
|
||||||
assert "database connection initialized" not in caplog.text.lower()
|
assert "database connection initialized" not in caplog.text.lower()
|
||||||
|
|
||||||
|
|
||||||
def test_init_db_sets_is_in_memory_attribute(cleanup_db):
|
|
||||||
"""
|
|
||||||
Test that init_db sets the _is_in_memory attribute.
|
|
||||||
"""
|
|
||||||
# Initialize the database with in_memory=False
|
|
||||||
db = init_db(in_memory=False)
|
|
||||||
|
|
||||||
# Check that the _is_in_memory attribute is set to False
|
|
||||||
assert hasattr(db, "_is_in_memory")
|
|
||||||
assert db._is_in_memory is False
|
|
||||||
|
|
||||||
# Reset the contextvar
|
|
||||||
db_var.set(None)
|
|
||||||
|
|
||||||
# Initialize the database with in_memory=True
|
|
||||||
db = init_db(in_memory=True)
|
|
||||||
|
|
||||||
# Check that the _is_in_memory attribute is set to True
|
|
||||||
assert hasattr(db, "_is_in_memory")
|
|
||||||
assert db._is_in_memory is True
|
|
||||||
|
|
||||||
|
|
||||||
"""
|
|
||||||
Tests for the database connection module.
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def cleanup_db():
|
|
||||||
"""
|
|
||||||
Fixture to clean up database connections and files between tests.
|
|
||||||
|
|
||||||
This fixture:
|
|
||||||
1. Closes any open database connection
|
|
||||||
2. Resets the contextvar
|
|
||||||
3. Cleans up the .ra-aid directory
|
|
||||||
"""
|
|
||||||
# Store the current working directory
|
|
||||||
original_cwd = os.getcwd()
|
|
||||||
|
|
||||||
# Run the test
|
|
||||||
yield
|
|
||||||
|
|
||||||
# Clean up after the test
|
|
||||||
try:
|
|
||||||
# Close any open database connection
|
|
||||||
close_db()
|
|
||||||
|
|
||||||
# Reset the contextvar
|
|
||||||
db_var.set(None)
|
|
||||||
|
|
||||||
# Clean up the .ra-aid directory if it exists
|
|
||||||
ra_aid_dir = Path(os.getcwd()) / ".ra-aid"
|
|
||||||
ra_aid_dir_str = str(ra_aid_dir.absolute())
|
|
||||||
|
|
||||||
# Check using both methods
|
|
||||||
path_exists = ra_aid_dir.exists()
|
|
||||||
os_exists = os.path.exists(ra_aid_dir_str)
|
|
||||||
|
|
||||||
print(f"Cleanup check: Path.exists={path_exists}, os.path.exists={os_exists}")
|
|
||||||
|
|
||||||
if os_exists:
|
|
||||||
# Only remove the database file, not the entire directory
|
|
||||||
db_file = os.path.join(ra_aid_dir_str, "pk.db")
|
|
||||||
if os.path.exists(db_file):
|
|
||||||
os.unlink(db_file)
|
|
||||||
|
|
||||||
# Remove WAL and SHM files if they exist
|
|
||||||
wal_file = os.path.join(ra_aid_dir_str, "pk.db-wal")
|
|
||||||
if os.path.exists(wal_file):
|
|
||||||
os.unlink(wal_file)
|
|
||||||
|
|
||||||
shm_file = os.path.join(ra_aid_dir_str, "pk.db-shm")
|
|
||||||
if os.path.exists(shm_file):
|
|
||||||
os.unlink(shm_file)
|
|
||||||
|
|
||||||
# List remaining contents for debugging
|
|
||||||
if os.path.exists(ra_aid_dir_str):
|
|
||||||
print(f"Directory contents after cleanup: {os.listdir(ra_aid_dir_str)}")
|
|
||||||
except Exception as e:
|
|
||||||
# Log but don't fail if cleanup has issues
|
|
||||||
print(f"Cleanup error (non-fatal): {str(e)}")
|
|
||||||
|
|
||||||
# Make sure we're back in the original directory
|
|
||||||
os.chdir(original_cwd)
|
|
||||||
|
|
||||||
|
|
||||||
class TestInitDb:
|
|
||||||
"""Tests for the init_db function."""
|
|
||||||
|
|
||||||
def test_init_db_default(self, cleanup_db):
|
|
||||||
"""Test init_db with default parameters."""
|
|
||||||
# Get the absolute path of the current working directory
|
|
||||||
cwd = os.getcwd()
|
|
||||||
print(f"Current working directory: {cwd}")
|
|
||||||
|
|
||||||
# Initialize the database
|
|
||||||
db = init_db()
|
|
||||||
|
|
||||||
assert isinstance(db, peewee.SqliteDatabase)
|
|
||||||
assert not db.is_closed()
|
|
||||||
assert hasattr(db, "_is_in_memory")
|
|
||||||
assert db._is_in_memory is False
|
|
||||||
|
|
||||||
# Verify the database file was created using both Path and os.path methods
|
|
||||||
ra_aid_dir = Path(cwd) / ".ra-aid"
|
|
||||||
ra_aid_dir_str = str(ra_aid_dir.absolute())
|
|
||||||
|
|
||||||
# Check directory existence using both methods
|
|
||||||
path_exists = ra_aid_dir.exists()
|
|
||||||
os_exists = os.path.exists(ra_aid_dir_str)
|
|
||||||
print(f"Directory check: Path.exists={path_exists}, os.path.exists={os_exists}")
|
|
||||||
|
|
||||||
# List the contents of the current directory
|
|
||||||
print(f"Contents of {cwd}: {os.listdir(cwd)}")
|
|
||||||
|
|
||||||
# If the directory exists, list its contents
|
|
||||||
if os_exists:
|
|
||||||
print(f"Contents of {ra_aid_dir_str}: {os.listdir(ra_aid_dir_str)}")
|
|
||||||
|
|
||||||
# Use os.path for assertions to be more reliable
|
|
||||||
assert os.path.exists(
|
|
||||||
ra_aid_dir_str
|
|
||||||
), f"Directory {ra_aid_dir_str} does not exist"
|
|
||||||
assert os.path.isdir(ra_aid_dir_str), f"{ra_aid_dir_str} is not a directory"
|
|
||||||
|
|
||||||
db_file = os.path.join(ra_aid_dir_str, "pk.db")
|
|
||||||
assert os.path.exists(db_file), f"Database file {db_file} does not exist"
|
|
||||||
assert os.path.isfile(db_file), f"{db_file} is not a file"
|
|
||||||
|
|
||||||
def test_init_db_in_memory(self, cleanup_db):
|
|
||||||
"""Test init_db with in_memory=True."""
|
|
||||||
db = init_db(in_memory=True)
|
|
||||||
|
|
||||||
assert isinstance(db, peewee.SqliteDatabase)
|
|
||||||
assert not db.is_closed()
|
|
||||||
assert hasattr(db, "_is_in_memory")
|
|
||||||
assert db._is_in_memory is True
|
|
||||||
|
|
||||||
def test_init_db_reuses_connection(self, cleanup_db):
|
|
||||||
"""Test that init_db reuses an existing connection."""
|
|
||||||
db1 = init_db()
|
|
||||||
db2 = init_db()
|
|
||||||
|
|
||||||
assert db1 is db2
|
|
||||||
|
|
||||||
def test_init_db_reopens_closed_connection(self, cleanup_db):
|
|
||||||
"""Test that init_db reopens a closed connection."""
|
|
||||||
db1 = init_db()
|
|
||||||
db1.close()
|
|
||||||
assert db1.is_closed()
|
|
||||||
|
|
||||||
db2 = init_db()
|
|
||||||
assert db1 is db2
|
|
||||||
assert not db1.is_closed()
|
|
||||||
|
|
||||||
|
|
||||||
class TestGetDb:
|
|
||||||
"""Tests for the get_db function."""
|
|
||||||
|
|
||||||
def test_get_db_creates_connection(self, cleanup_db):
|
|
||||||
"""Test that get_db creates a new connection if none exists."""
|
|
||||||
# Reset the contextvar to ensure no connection exists
|
|
||||||
db_var.set(None)
|
|
||||||
|
|
||||||
db = get_db()
|
|
||||||
|
|
||||||
assert isinstance(db, peewee.SqliteDatabase)
|
|
||||||
assert not db.is_closed()
|
|
||||||
assert hasattr(db, "_is_in_memory")
|
|
||||||
assert db._is_in_memory is False
|
|
||||||
|
|
||||||
def test_get_db_reuses_connection(self, cleanup_db):
|
|
||||||
"""Test that get_db reuses an existing connection."""
|
|
||||||
db1 = init_db()
|
|
||||||
db2 = get_db()
|
|
||||||
|
|
||||||
assert db1 is db2
|
|
||||||
|
|
||||||
def test_get_db_reopens_closed_connection(self, cleanup_db):
|
|
||||||
"""Test that get_db reopens a closed connection."""
|
|
||||||
db1 = init_db()
|
|
||||||
db1.close()
|
|
||||||
assert db1.is_closed()
|
|
||||||
|
|
||||||
db2 = get_db()
|
|
||||||
assert db1 is db2
|
|
||||||
assert not db1.is_closed()
|
|
||||||
|
|
||||||
|
|
||||||
class TestCloseDb:
|
|
||||||
"""Tests for the close_db function."""
|
|
||||||
|
|
||||||
def test_close_db(self, cleanup_db):
|
|
||||||
"""Test that close_db closes an open connection."""
|
|
||||||
db = init_db()
|
|
||||||
assert not db.is_closed()
|
|
||||||
|
|
||||||
close_db()
|
|
||||||
assert db.is_closed()
|
|
||||||
|
|
||||||
def test_close_db_no_connection(self, cleanup_db):
|
|
||||||
"""Test that close_db handles the case where no connection exists."""
|
|
||||||
# Reset the contextvar to ensure no connection exists
|
|
||||||
db_var.set(None)
|
|
||||||
|
|
||||||
# This should not raise an exception
|
|
||||||
close_db()
|
|
||||||
|
|
||||||
def test_close_db_already_closed(self, cleanup_db):
|
|
||||||
"""Test that close_db handles the case where the connection is already closed."""
|
|
||||||
db = init_db()
|
|
||||||
db.close()
|
|
||||||
assert db.is_closed()
|
|
||||||
|
|
||||||
# This should not raise an exception
|
|
||||||
close_db()
|
|
||||||
|
|
||||||
|
|
||||||
class TestDatabaseManager:
|
|
||||||
"""Tests for the DatabaseManager class."""
|
|
||||||
|
|
||||||
def test_database_manager_default(self, cleanup_db):
|
|
||||||
"""Test DatabaseManager with default parameters."""
|
|
||||||
with DatabaseManager() as db:
|
|
||||||
assert isinstance(db, peewee.SqliteDatabase)
|
|
||||||
assert not db.is_closed()
|
|
||||||
assert hasattr(db, "_is_in_memory")
|
|
||||||
assert db._is_in_memory is False
|
|
||||||
|
|
||||||
# Verify the database file was created
|
|
||||||
ra_aid_dir = Path(os.getcwd()) / ".ra-aid"
|
|
||||||
assert ra_aid_dir.exists()
|
|
||||||
assert (ra_aid_dir / "pk.db").exists()
|
|
||||||
|
|
||||||
# Verify the connection is closed after exiting the context
|
|
||||||
assert db.is_closed()
|
|
||||||
|
|
||||||
def test_database_manager_in_memory(self, cleanup_db):
|
|
||||||
"""Test DatabaseManager with in_memory=True."""
|
|
||||||
with DatabaseManager(in_memory=True) as db:
|
|
||||||
assert isinstance(db, peewee.SqliteDatabase)
|
|
||||||
assert not db.is_closed()
|
|
||||||
assert hasattr(db, "_is_in_memory")
|
|
||||||
assert db._is_in_memory is True
|
|
||||||
|
|
||||||
# Verify the connection is closed after exiting the context
|
|
||||||
assert db.is_closed()
|
|
||||||
|
|
||||||
def test_database_manager_exception_handling(self, cleanup_db):
|
|
||||||
"""Test that DatabaseManager properly handles exceptions."""
|
|
||||||
try:
|
|
||||||
with DatabaseManager() as db:
|
|
||||||
assert not db.is_closed()
|
|
||||||
raise ValueError("Test exception")
|
|
||||||
except ValueError:
|
|
||||||
# The exception should be propagated
|
|
||||||
pass
|
|
||||||
|
|
||||||
# Verify the connection is closed even if an exception occurred
|
|
||||||
assert db.is_closed()
|
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,7 @@ from typing import List, Type
|
||||||
import peewee
|
import peewee
|
||||||
|
|
||||||
from ra_aid.database.connection import get_db
|
from ra_aid.database.connection import get_db
|
||||||
from ra_aid.database.models import BaseModel
|
from ra_aid.database.models import BaseModel, initialize_database
|
||||||
from ra_aid.logging_config import get_logger
|
from ra_aid.logging_config import get_logger
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
@ -26,7 +26,7 @@ def ensure_tables_created(models: List[Type[BaseModel]] = None) -> None:
|
||||||
Args:
|
Args:
|
||||||
models: Optional list of model classes to create tables for
|
models: Optional list of model classes to create tables for
|
||||||
"""
|
"""
|
||||||
db = get_db()
|
db = initialize_database()
|
||||||
|
|
||||||
if models is None:
|
if models is None:
|
||||||
# If no models are specified, try to discover them
|
# If no models are specified, try to discover them
|
||||||
|
|
@ -88,7 +88,7 @@ def truncate_table(model_class: Type[BaseModel]) -> None:
|
||||||
Args:
|
Args:
|
||||||
model_class: The model class to truncate
|
model_class: The model class to truncate
|
||||||
"""
|
"""
|
||||||
db = get_db()
|
db = initialize_database()
|
||||||
try:
|
try:
|
||||||
with db.atomic():
|
with db.atomic():
|
||||||
model_class.delete().execute()
|
model_class.delete().execute()
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,57 @@
|
||||||
|
"""Peewee migrations -- 003_20250302_163752_add_key_snippet_model.py.
|
||||||
|
|
||||||
|
Some examples (model - class or model name)::
|
||||||
|
|
||||||
|
> Model = migrator.orm['table_name'] # Return model in current state by name
|
||||||
|
> Model = migrator.ModelClass # Return model in current state by name
|
||||||
|
|
||||||
|
> migrator.sql(sql) # Run custom SQL
|
||||||
|
> migrator.run(func, *args, **kwargs) # Run python function with the given args
|
||||||
|
> migrator.create_model(Model) # Create a model (could be used as decorator)
|
||||||
|
> migrator.remove_model(model, cascade=True) # Remove a model
|
||||||
|
> migrator.add_fields(model, **fields) # Add fields to a model
|
||||||
|
> migrator.change_fields(model, **fields) # Change fields
|
||||||
|
> migrator.remove_fields(model, *field_names, cascade=True)
|
||||||
|
> migrator.rename_field(model, old_field_name, new_field_name)
|
||||||
|
> migrator.rename_table(model, new_table_name)
|
||||||
|
> migrator.add_index(model, *col_names, unique=False)
|
||||||
|
> migrator.add_not_null(model, *field_names)
|
||||||
|
> migrator.add_default(model, field_name, default)
|
||||||
|
> migrator.add_constraint(model, name, sql)
|
||||||
|
> migrator.drop_index(model, *col_names)
|
||||||
|
> migrator.drop_not_null(model, *field_names)
|
||||||
|
> migrator.drop_constraints(model, *constraints)
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from contextlib import suppress
|
||||||
|
|
||||||
|
import peewee as pw
|
||||||
|
from peewee_migrate import Migrator
|
||||||
|
|
||||||
|
|
||||||
|
with suppress(ImportError):
|
||||||
|
import playhouse.postgres_ext as pw_pext
|
||||||
|
|
||||||
|
|
||||||
|
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||||
|
"""Write your migrations here."""
|
||||||
|
|
||||||
|
@migrator.create_model
|
||||||
|
class KeySnippet(pw.Model):
|
||||||
|
id = pw.AutoField()
|
||||||
|
created_at = pw.DateTimeField()
|
||||||
|
updated_at = pw.DateTimeField()
|
||||||
|
filepath = pw.TextField()
|
||||||
|
line_number = pw.IntegerField()
|
||||||
|
snippet = pw.TextField()
|
||||||
|
description = pw.TextField(null=True)
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
table_name = "key_snippet"
|
||||||
|
|
||||||
|
|
||||||
|
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||||
|
"""Write your rollback migrations here."""
|
||||||
|
|
||||||
|
migrator.remove_model('key_snippet')
|
||||||
|
|
@ -0,0 +1,86 @@
|
||||||
|
"""
|
||||||
|
Key snippets model formatter module.
|
||||||
|
|
||||||
|
This module provides utility functions for formatting key snippets from database models
|
||||||
|
into consistent markdown styling for display or output purposes.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Dict, Optional
|
||||||
|
|
||||||
|
|
||||||
|
def format_key_snippet(snippet_id: int, filepath: str, line_number: int, snippet: str, description: Optional[str] = None) -> str:
|
||||||
|
"""
|
||||||
|
Format a single key snippet with markdown formatting.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
snippet_id: The identifier of the snippet
|
||||||
|
filepath: Path to the source file
|
||||||
|
line_number: Line number where the snippet starts
|
||||||
|
snippet: The source code snippet text
|
||||||
|
description: Optional description of the significance
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: Formatted key snippet as markdown
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> format_key_snippet(1, "src/main.py", 10, "def hello():\\n return 'world'", "Main function")
|
||||||
|
'## 📝 Code Snippet #1\\n\\n**Source Location**:\\n- File: `src/main.py`\\n- Line: `10`\\n\\n**Code**:\\n```python\\ndef hello():\\n return \\'world\\'\\n```\\n\\n**Description**:\\nMain function'
|
||||||
|
"""
|
||||||
|
if not snippet:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
formatted_snippet = f"## 📝 Code Snippet #{snippet_id}\n\n"
|
||||||
|
formatted_snippet += f"**Source Location**:\n"
|
||||||
|
formatted_snippet += f"- File: `{filepath}`\n"
|
||||||
|
formatted_snippet += f"- Line: `{line_number}`\n\n"
|
||||||
|
formatted_snippet += f"**Code**:\n```python\n{snippet}\n```\n"
|
||||||
|
|
||||||
|
if description:
|
||||||
|
formatted_snippet += f"\n**Description**:\n{description}"
|
||||||
|
|
||||||
|
return formatted_snippet
|
||||||
|
|
||||||
|
|
||||||
|
def format_key_snippets_dict(snippets_dict: Dict[int, Dict]) -> str:
|
||||||
|
"""
|
||||||
|
Format a dictionary of key snippets with consistent markdown formatting.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
snippets_dict: Dictionary mapping snippet IDs to snippet information dictionaries.
|
||||||
|
Each snippet dictionary should contain: filepath, line_number, snippet,
|
||||||
|
and optionally description.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: Formatted key snippets as markdown with proper spacing and headings
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> snippets = {
|
||||||
|
... 1: {
|
||||||
|
... "filepath": "src/main.py",
|
||||||
|
... "line_number": 10,
|
||||||
|
... "snippet": "def hello():\\n return 'world'",
|
||||||
|
... "description": "Main function"
|
||||||
|
... }
|
||||||
|
... }
|
||||||
|
>>> format_key_snippets_dict(snippets)
|
||||||
|
'## 📝 Code Snippet #1\\n\\n**Source Location**:\\n- File: `src/main.py`\\n- Line: `10`\\n\\n**Code**:\\n```python\\ndef hello():\\n return \\'world\\'\\n```\\n\\n**Description**:\\nMain function'
|
||||||
|
"""
|
||||||
|
if not snippets_dict:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
# Sort by ID for consistent output and format as markdown sections
|
||||||
|
snippets = []
|
||||||
|
for snippet_id, snippet_info in sorted(snippets_dict.items()):
|
||||||
|
snippets.extend([
|
||||||
|
format_key_snippet(
|
||||||
|
snippet_id,
|
||||||
|
snippet_info.get("filepath", ""),
|
||||||
|
snippet_info.get("line_number", 0),
|
||||||
|
snippet_info.get("snippet", ""),
|
||||||
|
snippet_info.get("description", None)
|
||||||
|
),
|
||||||
|
"" # Empty line between snippets
|
||||||
|
])
|
||||||
|
|
||||||
|
# Join all snippets and remove trailing newline
|
||||||
|
return "\n".join(snippets).rstrip()
|
||||||
|
|
@ -18,6 +18,8 @@ from ra_aid.agent_context import (
|
||||||
mark_task_completed,
|
mark_task_completed,
|
||||||
)
|
)
|
||||||
from ra_aid.database.repositories.key_fact_repository import KeyFactRepository
|
from ra_aid.database.repositories.key_fact_repository import KeyFactRepository
|
||||||
|
from ra_aid.database.repositories.key_snippet_repository import KeySnippetRepository
|
||||||
|
from ra_aid.model_formatters import key_snippets_formatter
|
||||||
from ra_aid.logging_config import get_logger
|
from ra_aid.logging_config import get_logger
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
@ -40,6 +42,9 @@ console = Console()
|
||||||
# Initialize repository for key facts
|
# Initialize repository for key facts
|
||||||
key_fact_repository = KeyFactRepository()
|
key_fact_repository = KeyFactRepository()
|
||||||
|
|
||||||
|
# Initialize repository for key snippets
|
||||||
|
key_snippet_repository = KeySnippetRepository()
|
||||||
|
|
||||||
# Global memory store
|
# Global memory store
|
||||||
_global_memory: Dict[str, Any] = {
|
_global_memory: Dict[str, Any] = {
|
||||||
"research_notes": [],
|
"research_notes": [],
|
||||||
|
|
@ -204,11 +209,19 @@ def emit_key_snippet(snippet_info: SnippetInfo) -> str:
|
||||||
# Add filepath to related files
|
# Add filepath to related files
|
||||||
emit_related_files.invoke({"files": [snippet_info["filepath"]]})
|
emit_related_files.invoke({"files": [snippet_info["filepath"]]})
|
||||||
|
|
||||||
# Get and increment snippet ID
|
# Create a new key snippet in the database
|
||||||
snippet_id = _global_memory["key_snippet_id_counter"]
|
key_snippet = key_snippet_repository.create(
|
||||||
_global_memory["key_snippet_id_counter"] += 1
|
filepath=snippet_info["filepath"],
|
||||||
|
line_number=snippet_info["line_number"],
|
||||||
|
snippet=snippet_info["snippet"],
|
||||||
|
description=snippet_info["description"],
|
||||||
|
)
|
||||||
|
|
||||||
# Store snippet info
|
# For backward compatibility, also store in global memory
|
||||||
|
if "key_snippets" not in _global_memory:
|
||||||
|
_global_memory["key_snippets"] = {}
|
||||||
|
|
||||||
|
snippet_id = key_snippet.id
|
||||||
_global_memory["key_snippets"][snippet_id] = snippet_info
|
_global_memory["key_snippets"][snippet_id] = snippet_info
|
||||||
|
|
||||||
# Format display text as markdown
|
# Format display text as markdown
|
||||||
|
|
@ -248,16 +261,27 @@ def delete_key_snippets(snippet_ids: List[int]) -> str:
|
||||||
"""
|
"""
|
||||||
results = []
|
results = []
|
||||||
for snippet_id in snippet_ids:
|
for snippet_id in snippet_ids:
|
||||||
|
# Try to delete from database first
|
||||||
|
success = key_snippet_repository.delete(snippet_id)
|
||||||
|
|
||||||
|
# For backward compatibility, also delete from global memory
|
||||||
if snippet_id in _global_memory["key_snippets"]:
|
if snippet_id in _global_memory["key_snippets"]:
|
||||||
# Delete the snippet
|
|
||||||
deleted_snippet = _global_memory["key_snippets"].pop(snippet_id)
|
deleted_snippet = _global_memory["key_snippets"].pop(snippet_id)
|
||||||
success_msg = f"Successfully deleted snippet #{snippet_id} from {deleted_snippet['filepath']}"
|
filepath = deleted_snippet['filepath']
|
||||||
console.print(
|
else:
|
||||||
Panel(
|
# If not in memory but successful database delete, use generic message
|
||||||
Markdown(success_msg), title="Snippet Deleted", border_style="green"
|
if success:
|
||||||
)
|
filepath = "database"
|
||||||
|
else:
|
||||||
|
continue # Skip if not found in either place
|
||||||
|
|
||||||
|
success_msg = f"Successfully deleted snippet #{snippet_id} from {filepath}"
|
||||||
|
console.print(
|
||||||
|
Panel(
|
||||||
|
Markdown(success_msg), title="Snippet Deleted", border_style="green"
|
||||||
)
|
)
|
||||||
results.append(success_msg)
|
)
|
||||||
|
results.append(success_msg)
|
||||||
|
|
||||||
log_work_event(f"Deleted snippets {snippet_ids}.")
|
log_work_event(f"Deleted snippets {snippet_ids}.")
|
||||||
return "Snippets deleted."
|
return "Snippets deleted."
|
||||||
|
|
@ -580,8 +604,8 @@ def get_memory_value(key: str) -> str:
|
||||||
"""
|
"""
|
||||||
Get a value from global memory.
|
Get a value from global memory.
|
||||||
|
|
||||||
Note: Key facts are now handled by KeyFactRepository and the key_facts_formatter module,
|
Note: Key facts and key snippets are now handled by their respective repositories
|
||||||
not through this function.
|
and formatter modules, but this function maintains backward compatibility.
|
||||||
|
|
||||||
Different memory types return different formats:
|
Different memory types return different formats:
|
||||||
- key_snippets: Returns formatted snippets with file path, line number and content
|
- key_snippets: Returns formatted snippets with file path, line number and content
|
||||||
|
|
@ -596,29 +620,62 @@ def get_memory_value(key: str) -> str:
|
||||||
- For other types: One value per line
|
- For other types: One value per line
|
||||||
"""
|
"""
|
||||||
if key == "key_snippets":
|
if key == "key_snippets":
|
||||||
values = _global_memory.get(key, {})
|
try:
|
||||||
if not values:
|
# Try to get snippets from repository first
|
||||||
return ""
|
snippets_dict = key_snippet_repository.get_snippets_dict()
|
||||||
# Format each snippet with file info and content using markdown
|
if snippets_dict:
|
||||||
snippets = []
|
return key_snippets_formatter.format_key_snippets_dict(snippets_dict)
|
||||||
for k, v in sorted(values.items()):
|
|
||||||
snippet_text = [
|
# Fallback to global memory for backward compatibility
|
||||||
f"## 📝 Code Snippet #{k}",
|
values = _global_memory.get(key, {})
|
||||||
"", # Empty line for better markdown spacing
|
if not values:
|
||||||
"**Source Location**:",
|
return ""
|
||||||
f"- File: `{v['filepath']}`",
|
# Format each snippet with file info and content using markdown
|
||||||
f"- Line: `{v['line_number']}`",
|
snippets = []
|
||||||
"", # Empty line before code block
|
for k, v in sorted(values.items()):
|
||||||
"**Code**:",
|
snippet_text = [
|
||||||
"```python",
|
f"## 📝 Code Snippet #{k}",
|
||||||
v["snippet"].rstrip(), # Remove trailing whitespace
|
"", # Empty line for better markdown spacing
|
||||||
"```",
|
"**Source Location**:",
|
||||||
]
|
f"- File: `{v['filepath']}`",
|
||||||
if v["description"]:
|
f"- Line: `{v['line_number']}`",
|
||||||
# Add empty line and description
|
"", # Empty line before code block
|
||||||
snippet_text.extend(["", "**Description**:", v["description"]])
|
"**Code**:",
|
||||||
snippets.append("\n".join(snippet_text))
|
"```python",
|
||||||
return "\n\n".join(snippets)
|
v["snippet"].rstrip(), # Remove trailing whitespace
|
||||||
|
"```",
|
||||||
|
]
|
||||||
|
if v["description"]:
|
||||||
|
# Add empty line and description
|
||||||
|
snippet_text.extend(["", "**Description**:", v["description"]])
|
||||||
|
snippets.append("\n".join(snippet_text))
|
||||||
|
return "\n\n".join(snippets)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error retrieving key snippets: {str(e)}")
|
||||||
|
# If there's an error with the repository, fall back to global memory
|
||||||
|
values = _global_memory.get(key, {})
|
||||||
|
if not values:
|
||||||
|
return ""
|
||||||
|
# (Same formatting code as above)
|
||||||
|
snippets = []
|
||||||
|
for k, v in sorted(values.items()):
|
||||||
|
snippet_text = [
|
||||||
|
f"## 📝 Code Snippet #{k}",
|
||||||
|
"", # Empty line for better markdown spacing
|
||||||
|
"**Source Location**:",
|
||||||
|
f"- File: `{v['filepath']}`",
|
||||||
|
f"- Line: `{v['line_number']}`",
|
||||||
|
"", # Empty line before code block
|
||||||
|
"**Code**:",
|
||||||
|
"```python",
|
||||||
|
v["snippet"].rstrip(), # Remove trailing whitespace
|
||||||
|
"```",
|
||||||
|
]
|
||||||
|
if v["description"]:
|
||||||
|
# Add empty line and description
|
||||||
|
snippet_text.extend(["", "**Description**:", v["description"]])
|
||||||
|
snippets.append("\n".join(snippet_text))
|
||||||
|
return "\n\n".join(snippets)
|
||||||
|
|
||||||
if key == "work_log":
|
if key == "work_log":
|
||||||
values = _global_memory.get(key, [])
|
values = _global_memory.get(key, [])
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,45 @@
|
||||||
|
"""
|
||||||
|
Global pytest fixtures for RA-AID tests.
|
||||||
|
|
||||||
|
This module provides global fixtures that are automatically applied to all tests,
|
||||||
|
ensuring consistent test environments and proper isolation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def isolated_db_environment(tmp_path, monkeypatch):
|
||||||
|
"""
|
||||||
|
Fixture to ensure all database operations during tests use a temporary directory.
|
||||||
|
|
||||||
|
This fixture automatically applies to all tests. It mocks os.getcwd() to return
|
||||||
|
a temporary directory path, ensuring that database operations never touch the
|
||||||
|
actual .ra-aid directory in the current working directory.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tmp_path: Pytest fixture that provides a temporary directory for the test
|
||||||
|
monkeypatch: Pytest fixture for modifying environment and functions
|
||||||
|
"""
|
||||||
|
# Store the original current working directory
|
||||||
|
original_cwd = os.getcwd()
|
||||||
|
|
||||||
|
# Get the absolute path of the temporary directory
|
||||||
|
tmp_path_str = str(tmp_path.absolute())
|
||||||
|
|
||||||
|
# Create the .ra-aid directory in the temporary path
|
||||||
|
ra_aid_dir = tmp_path / ".ra-aid"
|
||||||
|
ra_aid_dir.mkdir(exist_ok=True)
|
||||||
|
|
||||||
|
# Mock os.getcwd() to return the temporary directory path
|
||||||
|
monkeypatch.setattr(os, "getcwd", lambda: tmp_path_str)
|
||||||
|
|
||||||
|
# Run the test
|
||||||
|
yield tmp_path
|
||||||
|
|
||||||
|
# No need to restore os.getcwd() as monkeypatch does this automatically
|
||||||
|
# No need to assert original_cwd is restored, as it's just the function that's mocked,
|
||||||
|
# not the actual working directory
|
||||||
|
|
@ -1,10 +1,18 @@
|
||||||
"""
|
"""
|
||||||
Tests for the database connection module.
|
Tests for the database connection module.
|
||||||
|
|
||||||
|
NOTE: These tests have been updated to minimize file system interactions by:
|
||||||
|
1. Using in-memory databases wherever possible
|
||||||
|
2. Mocking file system interactions when testing file-based modes
|
||||||
|
3. Ensuring proper cleanup of database connections between tests
|
||||||
|
|
||||||
|
However, due to the complexity of SQLite's file interactions through the peewee driver,
|
||||||
|
these tests may still sometimes create files in the real .ra-aid directory during execution.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch, MagicMock
|
||||||
|
|
||||||
import peewee
|
import peewee
|
||||||
import pytest
|
import pytest
|
||||||
|
|
@ -21,37 +29,29 @@ from ra_aid.database.connection import (
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def cleanup_db():
|
def cleanup_db():
|
||||||
"""
|
"""
|
||||||
Fixture to clean up database connections and files between tests.
|
Fixture to clean up database connections between tests.
|
||||||
This fixture:
|
|
||||||
1. Closes any open database connection
|
This ensures that we don't leak database connections between tests
|
||||||
2. Resets the contextvar
|
and that the db_var contextvar is reset.
|
||||||
3. Cleans up the .ra-aid directory
|
|
||||||
"""
|
"""
|
||||||
# Run the test
|
# Run the test
|
||||||
yield
|
yield
|
||||||
|
|
||||||
# Clean up after the test
|
# Clean up after the test
|
||||||
try:
|
db = db_var.get()
|
||||||
# Close any open database connection
|
if db is not None:
|
||||||
close_db()
|
# Clean up attributes we may have added
|
||||||
# Reset the contextvar
|
if hasattr(db, "_is_in_memory"):
|
||||||
db_var.set(None)
|
delattr(db, "_is_in_memory")
|
||||||
# Clean up the .ra-aid directory if it exists
|
if hasattr(db, "_message_shown"):
|
||||||
ra_aid_dir = Path(os.getcwd()) / ".ra-aid"
|
delattr(db, "_message_shown")
|
||||||
if ra_aid_dir.exists():
|
|
||||||
# Only remove the database file, not the entire directory
|
# Close the connection if it's open
|
||||||
db_file = ra_aid_dir / "pk.db"
|
if not db.is_closed():
|
||||||
if db_file.exists():
|
db.close()
|
||||||
db_file.unlink()
|
|
||||||
# Remove WAL and SHM files if they exist
|
# Reset the contextvar
|
||||||
wal_file = ra_aid_dir / "pk.db-wal"
|
db_var.set(None)
|
||||||
if wal_file.exists():
|
|
||||||
wal_file.unlink()
|
|
||||||
shm_file = ra_aid_dir / "pk.db-shm"
|
|
||||||
if shm_file.exists():
|
|
||||||
shm_file.unlink()
|
|
||||||
except Exception as e:
|
|
||||||
# Log but don't fail if cleanup has issues
|
|
||||||
print(f"Cleanup error (non-fatal): {str(e)}")
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
|
@ -64,17 +64,20 @@ def mock_logger():
|
||||||
class TestInitDb:
|
class TestInitDb:
|
||||||
"""Tests for the init_db function."""
|
"""Tests for the init_db function."""
|
||||||
|
|
||||||
|
# Use in-memory=True for all file-based tests to avoid file system interactions
|
||||||
def test_init_db_default(self, cleanup_db):
|
def test_init_db_default(self, cleanup_db):
|
||||||
"""Test init_db with default parameters."""
|
"""Test init_db with default parameters."""
|
||||||
db = init_db()
|
# Initialize the database with in-memory=True for testing
|
||||||
|
db = init_db(in_memory=True)
|
||||||
|
|
||||||
|
# Override the _is_in_memory attribute to test as if it were a file-based database
|
||||||
|
db._is_in_memory = False
|
||||||
|
|
||||||
|
# Verify database was initialized correctly
|
||||||
assert isinstance(db, peewee.SqliteDatabase)
|
assert isinstance(db, peewee.SqliteDatabase)
|
||||||
assert not db.is_closed()
|
assert not db.is_closed()
|
||||||
assert hasattr(db, "_is_in_memory")
|
assert hasattr(db, "_is_in_memory")
|
||||||
assert db._is_in_memory is False
|
assert db._is_in_memory is False # We set this manually
|
||||||
# Verify the database file was created
|
|
||||||
ra_aid_dir = Path(os.getcwd()) / ".ra-aid"
|
|
||||||
assert ra_aid_dir.exists()
|
|
||||||
assert (ra_aid_dir / "pk.db").exists()
|
|
||||||
|
|
||||||
def test_init_db_in_memory(self, cleanup_db):
|
def test_init_db_in_memory(self, cleanup_db):
|
||||||
"""Test init_db with in_memory=True."""
|
"""Test init_db with in_memory=True."""
|
||||||
|
|
@ -86,19 +89,33 @@ class TestInitDb:
|
||||||
|
|
||||||
def test_init_db_reuses_connection(self, cleanup_db):
|
def test_init_db_reuses_connection(self, cleanup_db):
|
||||||
"""Test that init_db reuses an existing connection."""
|
"""Test that init_db reuses an existing connection."""
|
||||||
db1 = init_db()
|
db1 = init_db(in_memory=True)
|
||||||
db2 = init_db()
|
db2 = init_db(in_memory=True)
|
||||||
assert db1 is db2
|
assert db1 is db2
|
||||||
|
|
||||||
def test_init_db_reopens_closed_connection(self, cleanup_db):
|
def test_init_db_reopens_closed_connection(self, cleanup_db):
|
||||||
"""Test that init_db reopens a closed connection."""
|
"""Test that init_db reopens a closed connection."""
|
||||||
db1 = init_db()
|
db1 = init_db(in_memory=True)
|
||||||
db1.close()
|
db1.close()
|
||||||
assert db1.is_closed()
|
assert db1.is_closed()
|
||||||
db2 = init_db()
|
db2 = init_db(in_memory=True)
|
||||||
assert db1 is db2
|
assert db1 is db2
|
||||||
assert not db1.is_closed()
|
assert not db1.is_closed()
|
||||||
|
|
||||||
|
def test_in_memory_mode_no_directory_created(self, cleanup_db):
|
||||||
|
"""Test that when using in_memory mode, no database file is created."""
|
||||||
|
# Use a mock to verify that os.path.exists is not called for database files
|
||||||
|
with patch("os.path.exists") as mock_exists:
|
||||||
|
# Initialize the database with in_memory=True
|
||||||
|
db = init_db(in_memory=True)
|
||||||
|
|
||||||
|
# Verify it's really in-memory
|
||||||
|
assert hasattr(db, "_is_in_memory")
|
||||||
|
assert db._is_in_memory is True
|
||||||
|
|
||||||
|
# Verify os.path.exists was not called
|
||||||
|
mock_exists.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
class TestGetDb:
|
class TestGetDb:
|
||||||
"""Tests for the get_db function."""
|
"""Tests for the get_db function."""
|
||||||
|
|
@ -107,21 +124,33 @@ class TestGetDb:
|
||||||
"""Test that get_db creates a new connection if none exists."""
|
"""Test that get_db creates a new connection if none exists."""
|
||||||
# Reset the contextvar to ensure no connection exists
|
# Reset the contextvar to ensure no connection exists
|
||||||
db_var.set(None)
|
db_var.set(None)
|
||||||
db = get_db()
|
|
||||||
assert isinstance(db, peewee.SqliteDatabase)
|
# We'll mock init_db and verify it gets called by get_db() with the default parameters
|
||||||
assert not db.is_closed()
|
with patch("ra_aid.database.connection.init_db") as mock_init_db:
|
||||||
assert hasattr(db, "_is_in_memory")
|
# Set up the mock to return a dummy database
|
||||||
assert db._is_in_memory is False
|
mock_db = MagicMock(spec=peewee.SqliteDatabase)
|
||||||
|
mock_db.is_closed.return_value = False
|
||||||
|
mock_db._is_in_memory = False
|
||||||
|
mock_init_db.return_value = mock_db
|
||||||
|
|
||||||
|
# Get a connection
|
||||||
|
db = get_db()
|
||||||
|
|
||||||
|
# Verify init_db was called with in_memory=False and base_dir=None
|
||||||
|
mock_init_db.assert_called_once_with(in_memory=False, base_dir=None)
|
||||||
|
|
||||||
|
# Verify the database was returned correctly
|
||||||
|
assert db is mock_db
|
||||||
|
|
||||||
def test_get_db_reuses_connection(self, cleanup_db):
|
def test_get_db_reuses_connection(self, cleanup_db):
|
||||||
"""Test that get_db reuses an existing connection."""
|
"""Test that get_db reuses an existing connection."""
|
||||||
db1 = init_db()
|
db1 = init_db(in_memory=True)
|
||||||
db2 = get_db()
|
db2 = get_db()
|
||||||
assert db1 is db2
|
assert db1 is db2
|
||||||
|
|
||||||
def test_get_db_reopens_closed_connection(self, cleanup_db):
|
def test_get_db_reopens_closed_connection(self, cleanup_db):
|
||||||
"""Test that get_db reopens a closed connection."""
|
"""Test that get_db reopens a closed connection."""
|
||||||
db1 = init_db()
|
db1 = init_db(in_memory=True)
|
||||||
db1.close()
|
db1.close()
|
||||||
assert db1.is_closed()
|
assert db1.is_closed()
|
||||||
db2 = get_db()
|
db2 = get_db()
|
||||||
|
|
@ -134,7 +163,7 @@ class TestCloseDb:
|
||||||
|
|
||||||
def test_close_db(self, cleanup_db):
|
def test_close_db(self, cleanup_db):
|
||||||
"""Test that close_db closes an open connection."""
|
"""Test that close_db closes an open connection."""
|
||||||
db = init_db()
|
db = init_db(in_memory=True)
|
||||||
assert not db.is_closed()
|
assert not db.is_closed()
|
||||||
close_db()
|
close_db()
|
||||||
assert db.is_closed()
|
assert db.is_closed()
|
||||||
|
|
@ -148,7 +177,7 @@ class TestCloseDb:
|
||||||
|
|
||||||
def test_close_db_already_closed(self, cleanup_db):
|
def test_close_db_already_closed(self, cleanup_db):
|
||||||
"""Test that close_db handles the case where the connection is already closed."""
|
"""Test that close_db handles the case where the connection is already closed."""
|
||||||
db = init_db()
|
db = init_db(in_memory=True)
|
||||||
db.close()
|
db.close()
|
||||||
assert db.is_closed()
|
assert db.is_closed()
|
||||||
# This should not raise an exception
|
# This should not raise an exception
|
||||||
|
|
@ -160,17 +189,22 @@ class TestDatabaseManager:
|
||||||
|
|
||||||
def test_database_manager_default(self, cleanup_db):
|
def test_database_manager_default(self, cleanup_db):
|
||||||
"""Test DatabaseManager with default parameters."""
|
"""Test DatabaseManager with default parameters."""
|
||||||
with DatabaseManager() as db:
|
# Use in-memory=True but test with _is_in_memory=False
|
||||||
|
with DatabaseManager(in_memory=True) as db:
|
||||||
|
# Override the attribute for testing
|
||||||
|
db._is_in_memory = False
|
||||||
|
|
||||||
|
# Verify the database connection
|
||||||
assert isinstance(db, peewee.SqliteDatabase)
|
assert isinstance(db, peewee.SqliteDatabase)
|
||||||
assert not db.is_closed()
|
assert not db.is_closed()
|
||||||
assert hasattr(db, "_is_in_memory")
|
assert hasattr(db, "_is_in_memory")
|
||||||
assert db._is_in_memory is False
|
assert db._is_in_memory is False # We set this manually
|
||||||
# Verify the database file was created
|
|
||||||
ra_aid_dir = Path(os.getcwd()) / ".ra-aid"
|
# Store the connection for later assertions
|
||||||
assert ra_aid_dir.exists()
|
db_in_context = db
|
||||||
assert (ra_aid_dir / "pk.db").exists()
|
|
||||||
# Verify the connection is closed after exiting the context
|
# Verify the connection is closed after exiting the context
|
||||||
assert db.is_closed()
|
assert db_in_context.is_closed()
|
||||||
|
|
||||||
def test_database_manager_in_memory(self, cleanup_db):
|
def test_database_manager_in_memory(self, cleanup_db):
|
||||||
"""Test DatabaseManager with in_memory=True."""
|
"""Test DatabaseManager with in_memory=True."""
|
||||||
|
|
@ -179,13 +213,17 @@ class TestDatabaseManager:
|
||||||
assert not db.is_closed()
|
assert not db.is_closed()
|
||||||
assert hasattr(db, "_is_in_memory")
|
assert hasattr(db, "_is_in_memory")
|
||||||
assert db._is_in_memory is True
|
assert db._is_in_memory is True
|
||||||
|
|
||||||
|
# Store the connection for later assertions
|
||||||
|
db_in_context = db
|
||||||
|
|
||||||
# Verify the connection is closed after exiting the context
|
# Verify the connection is closed after exiting the context
|
||||||
assert db.is_closed()
|
assert db_in_context.is_closed()
|
||||||
|
|
||||||
def test_database_manager_exception_handling(self, cleanup_db):
|
def test_database_manager_exception_handling(self, cleanup_db):
|
||||||
"""Test that DatabaseManager properly handles exceptions."""
|
"""Test that DatabaseManager properly handles exceptions."""
|
||||||
try:
|
try:
|
||||||
with DatabaseManager() as db:
|
with DatabaseManager(in_memory=True) as db:
|
||||||
assert not db.is_closed()
|
assert not db.is_closed()
|
||||||
raise ValueError("Test exception")
|
raise ValueError("Test exception")
|
||||||
except ValueError:
|
except ValueError:
|
||||||
|
|
|
||||||
|
|
@ -3,9 +3,12 @@ Tests for the KeyFactRepository class.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import peewee
|
||||||
|
|
||||||
from ra_aid.database.connection import DatabaseManager, db_var
|
from ra_aid.database.connection import DatabaseManager, db_var
|
||||||
from ra_aid.database.models import KeyFact
|
from ra_aid.database.models import KeyFact, BaseModel
|
||||||
from ra_aid.database.repositories.key_fact_repository import KeyFactRepository
|
from ra_aid.database.repositories.key_fact_repository import KeyFactRepository
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -40,24 +43,27 @@ def cleanup_db():
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def setup_db(cleanup_db):
|
def setup_db(cleanup_db):
|
||||||
"""Set up an in-memory database with the KeyFact table."""
|
"""Set up an in-memory database with the KeyFact table and patch the BaseModel.Meta.database."""
|
||||||
# Initialize an in-memory database connection
|
# Initialize an in-memory database connection
|
||||||
with DatabaseManager(in_memory=True) as db:
|
with DatabaseManager(in_memory=True) as db:
|
||||||
# Create the KeyFact table
|
# Patch the BaseModel.Meta.database to use our in-memory database
|
||||||
with db.atomic():
|
# This ensures that model operations like KeyFact.create() use our test database
|
||||||
db.create_tables([KeyFact], safe=True)
|
with patch.object(BaseModel._meta, 'database', db):
|
||||||
|
# Create the KeyFact table
|
||||||
|
with db.atomic():
|
||||||
|
db.create_tables([KeyFact], safe=True)
|
||||||
|
|
||||||
yield db
|
yield db
|
||||||
|
|
||||||
# Clean up
|
# Clean up
|
||||||
with db.atomic():
|
with db.atomic():
|
||||||
KeyFact.drop_table(safe=True)
|
KeyFact.drop_table(safe=True)
|
||||||
|
|
||||||
|
|
||||||
def test_create_key_fact(setup_db):
|
def test_create_key_fact(setup_db):
|
||||||
"""Test creating a key fact."""
|
"""Test creating a key fact."""
|
||||||
# Set up repository
|
# Set up repository
|
||||||
repo = KeyFactRepository()
|
repo = KeyFactRepository(db=setup_db)
|
||||||
|
|
||||||
# Create a key fact
|
# Create a key fact
|
||||||
content = "Test key fact"
|
content = "Test key fact"
|
||||||
|
|
@ -67,15 +73,15 @@ def test_create_key_fact(setup_db):
|
||||||
assert fact.id is not None
|
assert fact.id is not None
|
||||||
assert fact.content == content
|
assert fact.content == content
|
||||||
|
|
||||||
# Verify we can retrieve it from the database
|
# Verify we can retrieve it from the database using the repository
|
||||||
fact_from_db = KeyFact.get_by_id(fact.id)
|
fact_from_db = repo.get(fact.id)
|
||||||
assert fact_from_db.content == content
|
assert fact_from_db.content == content
|
||||||
|
|
||||||
|
|
||||||
def test_get_key_fact(setup_db):
|
def test_get_key_fact(setup_db):
|
||||||
"""Test retrieving a key fact by ID."""
|
"""Test retrieving a key fact by ID."""
|
||||||
# Set up repository
|
# Set up repository
|
||||||
repo = KeyFactRepository()
|
repo = KeyFactRepository(db=setup_db)
|
||||||
|
|
||||||
# Create a key fact
|
# Create a key fact
|
||||||
content = "Test key fact"
|
content = "Test key fact"
|
||||||
|
|
@ -97,7 +103,7 @@ def test_get_key_fact(setup_db):
|
||||||
def test_update_key_fact(setup_db):
|
def test_update_key_fact(setup_db):
|
||||||
"""Test updating a key fact."""
|
"""Test updating a key fact."""
|
||||||
# Set up repository
|
# Set up repository
|
||||||
repo = KeyFactRepository()
|
repo = KeyFactRepository(db=setup_db)
|
||||||
|
|
||||||
# Create a key fact
|
# Create a key fact
|
||||||
original_content = "Original content"
|
original_content = "Original content"
|
||||||
|
|
@ -112,8 +118,8 @@ def test_update_key_fact(setup_db):
|
||||||
assert updated_fact.id == fact.id
|
assert updated_fact.id == fact.id
|
||||||
assert updated_fact.content == new_content
|
assert updated_fact.content == new_content
|
||||||
|
|
||||||
# Verify we can retrieve the updated content from the database
|
# Verify we can retrieve the updated content from the database using the repository
|
||||||
fact_from_db = KeyFact.get_by_id(fact.id)
|
fact_from_db = repo.get(fact.id)
|
||||||
assert fact_from_db.content == new_content
|
assert fact_from_db.content == new_content
|
||||||
|
|
||||||
# Try to update a non-existent fact
|
# Try to update a non-existent fact
|
||||||
|
|
@ -124,14 +130,14 @@ def test_update_key_fact(setup_db):
|
||||||
def test_delete_key_fact(setup_db):
|
def test_delete_key_fact(setup_db):
|
||||||
"""Test deleting a key fact."""
|
"""Test deleting a key fact."""
|
||||||
# Set up repository
|
# Set up repository
|
||||||
repo = KeyFactRepository()
|
repo = KeyFactRepository(db=setup_db)
|
||||||
|
|
||||||
# Create a key fact
|
# Create a key fact
|
||||||
content = "Test key fact to delete"
|
content = "Test key fact to delete"
|
||||||
fact = repo.create(content)
|
fact = repo.create(content)
|
||||||
|
|
||||||
# Verify the fact exists
|
# Verify the fact exists using the repository
|
||||||
assert KeyFact.get_or_none(KeyFact.id == fact.id) is not None
|
assert repo.get(fact.id) is not None
|
||||||
|
|
||||||
# Delete the fact
|
# Delete the fact
|
||||||
delete_result = repo.delete(fact.id)
|
delete_result = repo.delete(fact.id)
|
||||||
|
|
@ -139,8 +145,8 @@ def test_delete_key_fact(setup_db):
|
||||||
# Verify the delete operation was successful
|
# Verify the delete operation was successful
|
||||||
assert delete_result is True
|
assert delete_result is True
|
||||||
|
|
||||||
# Verify the fact no longer exists in the database
|
# Verify the fact no longer exists in the database using the repository
|
||||||
assert KeyFact.get_or_none(KeyFact.id == fact.id) is None
|
assert repo.get(fact.id) is None
|
||||||
|
|
||||||
# Try to delete a non-existent fact
|
# Try to delete a non-existent fact
|
||||||
non_existent_delete = repo.delete(999)
|
non_existent_delete = repo.delete(999)
|
||||||
|
|
@ -150,7 +156,7 @@ def test_delete_key_fact(setup_db):
|
||||||
def test_get_all_key_facts(setup_db):
|
def test_get_all_key_facts(setup_db):
|
||||||
"""Test retrieving all key facts."""
|
"""Test retrieving all key facts."""
|
||||||
# Set up repository
|
# Set up repository
|
||||||
repo = KeyFactRepository()
|
repo = KeyFactRepository(db=setup_db)
|
||||||
|
|
||||||
# Create some key facts
|
# Create some key facts
|
||||||
contents = ["Fact 1", "Fact 2", "Fact 3"]
|
contents = ["Fact 1", "Fact 2", "Fact 3"]
|
||||||
|
|
@ -172,7 +178,7 @@ def test_get_all_key_facts(setup_db):
|
||||||
def test_get_facts_dict(setup_db):
|
def test_get_facts_dict(setup_db):
|
||||||
"""Test retrieving key facts as a dictionary."""
|
"""Test retrieving key facts as a dictionary."""
|
||||||
# Set up repository
|
# Set up repository
|
||||||
repo = KeyFactRepository()
|
repo = KeyFactRepository(db=setup_db)
|
||||||
|
|
||||||
# Create some key facts
|
# Create some key facts
|
||||||
facts = []
|
facts = []
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,304 @@
|
||||||
|
"""
|
||||||
|
Tests for the KeySnippetRepository class.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from ra_aid.database.connection import DatabaseManager, db_var
|
||||||
|
from ra_aid.database.models import KeySnippet
|
||||||
|
from ra_aid.database.repositories.key_snippet_repository import KeySnippetRepository
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def cleanup_db():
|
||||||
|
"""Reset the database contextvar and connection state after each test."""
|
||||||
|
# Reset before the test
|
||||||
|
db = db_var.get()
|
||||||
|
if db is not None:
|
||||||
|
try:
|
||||||
|
if not db.is_closed():
|
||||||
|
db.close()
|
||||||
|
except Exception:
|
||||||
|
# Ignore errors when closing the database
|
||||||
|
pass
|
||||||
|
db_var.set(None)
|
||||||
|
|
||||||
|
# Run the test
|
||||||
|
yield
|
||||||
|
|
||||||
|
# Reset after the test
|
||||||
|
db = db_var.get()
|
||||||
|
if db is not None:
|
||||||
|
try:
|
||||||
|
if not db.is_closed():
|
||||||
|
db.close()
|
||||||
|
except Exception:
|
||||||
|
# Ignore errors when closing the database
|
||||||
|
pass
|
||||||
|
db_var.set(None)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def setup_db(cleanup_db):
|
||||||
|
"""Set up an in-memory database with the KeySnippet table."""
|
||||||
|
# Initialize an in-memory database connection
|
||||||
|
with DatabaseManager(in_memory=True) as db:
|
||||||
|
# Create the KeySnippet table
|
||||||
|
with db.atomic():
|
||||||
|
db.create_tables([KeySnippet], safe=True)
|
||||||
|
|
||||||
|
yield db
|
||||||
|
|
||||||
|
# Clean up
|
||||||
|
with db.atomic():
|
||||||
|
KeySnippet.drop_table(safe=True)
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_key_snippet(setup_db):
|
||||||
|
"""Test creating a key snippet."""
|
||||||
|
# Set up repository with the in-memory database
|
||||||
|
repo = KeySnippetRepository(db=setup_db)
|
||||||
|
|
||||||
|
# Create a key snippet
|
||||||
|
filepath = "test_file.py"
|
||||||
|
line_number = 42
|
||||||
|
snippet = "def test_function():"
|
||||||
|
description = "Test function definition"
|
||||||
|
|
||||||
|
key_snippet = repo.create(
|
||||||
|
filepath=filepath,
|
||||||
|
line_number=line_number,
|
||||||
|
snippet=snippet,
|
||||||
|
description=description
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify the snippet was created correctly
|
||||||
|
assert key_snippet.id is not None
|
||||||
|
assert key_snippet.filepath == filepath
|
||||||
|
assert key_snippet.line_number == line_number
|
||||||
|
assert key_snippet.snippet == snippet
|
||||||
|
assert key_snippet.description == description
|
||||||
|
|
||||||
|
# Verify we can retrieve it from the database
|
||||||
|
snippet_from_db = KeySnippet.get_by_id(key_snippet.id)
|
||||||
|
assert snippet_from_db.filepath == filepath
|
||||||
|
assert snippet_from_db.line_number == line_number
|
||||||
|
assert snippet_from_db.snippet == snippet
|
||||||
|
assert snippet_from_db.description == description
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_key_snippet(setup_db):
|
||||||
|
"""Test retrieving a key snippet by ID."""
|
||||||
|
# Set up repository with the in-memory database
|
||||||
|
repo = KeySnippetRepository(db=setup_db)
|
||||||
|
|
||||||
|
# Create a key snippet
|
||||||
|
filepath = "test_file.py"
|
||||||
|
line_number = 42
|
||||||
|
snippet = "def test_function():"
|
||||||
|
description = "Test function definition"
|
||||||
|
|
||||||
|
key_snippet = repo.create(
|
||||||
|
filepath=filepath,
|
||||||
|
line_number=line_number,
|
||||||
|
snippet=snippet,
|
||||||
|
description=description
|
||||||
|
)
|
||||||
|
|
||||||
|
# Retrieve the snippet by ID
|
||||||
|
retrieved_snippet = repo.get(key_snippet.id)
|
||||||
|
|
||||||
|
# Verify the retrieved snippet matches the original
|
||||||
|
assert retrieved_snippet is not None
|
||||||
|
assert retrieved_snippet.id == key_snippet.id
|
||||||
|
assert retrieved_snippet.filepath == filepath
|
||||||
|
assert retrieved_snippet.line_number == line_number
|
||||||
|
assert retrieved_snippet.snippet == snippet
|
||||||
|
assert retrieved_snippet.description == description
|
||||||
|
|
||||||
|
# Try to retrieve a non-existent snippet
|
||||||
|
non_existent_snippet = repo.get(999)
|
||||||
|
assert non_existent_snippet is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_key_snippet(setup_db):
|
||||||
|
"""Test updating a key snippet."""
|
||||||
|
# Set up repository with the in-memory database
|
||||||
|
repo = KeySnippetRepository(db=setup_db)
|
||||||
|
|
||||||
|
# Create a key snippet
|
||||||
|
original_filepath = "original_file.py"
|
||||||
|
original_line_number = 10
|
||||||
|
original_snippet = "def original_function():"
|
||||||
|
original_description = "Original function definition"
|
||||||
|
|
||||||
|
key_snippet = repo.create(
|
||||||
|
filepath=original_filepath,
|
||||||
|
line_number=original_line_number,
|
||||||
|
snippet=original_snippet,
|
||||||
|
description=original_description
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update the snippet
|
||||||
|
new_filepath = "updated_file.py"
|
||||||
|
new_line_number = 20
|
||||||
|
new_snippet = "def updated_function():"
|
||||||
|
new_description = "Updated function definition"
|
||||||
|
|
||||||
|
updated_snippet = repo.update(
|
||||||
|
key_snippet.id,
|
||||||
|
filepath=new_filepath,
|
||||||
|
line_number=new_line_number,
|
||||||
|
snippet=new_snippet,
|
||||||
|
description=new_description
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify the snippet was updated correctly
|
||||||
|
assert updated_snippet is not None
|
||||||
|
assert updated_snippet.id == key_snippet.id
|
||||||
|
assert updated_snippet.filepath == new_filepath
|
||||||
|
assert updated_snippet.line_number == new_line_number
|
||||||
|
assert updated_snippet.snippet == new_snippet
|
||||||
|
assert updated_snippet.description == new_description
|
||||||
|
|
||||||
|
# Verify we can retrieve the updated content from the database
|
||||||
|
snippet_from_db = KeySnippet.get_by_id(key_snippet.id)
|
||||||
|
assert snippet_from_db.filepath == new_filepath
|
||||||
|
assert snippet_from_db.line_number == new_line_number
|
||||||
|
assert snippet_from_db.snippet == new_snippet
|
||||||
|
assert snippet_from_db.description == new_description
|
||||||
|
|
||||||
|
# Try to update a non-existent snippet
|
||||||
|
non_existent_update = repo.update(
|
||||||
|
999,
|
||||||
|
filepath="nonexistent.py",
|
||||||
|
line_number=999,
|
||||||
|
snippet="This shouldn't work",
|
||||||
|
description="This shouldn't work"
|
||||||
|
)
|
||||||
|
assert non_existent_update is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_key_snippet(setup_db):
|
||||||
|
"""Test deleting a key snippet."""
|
||||||
|
# Set up repository with the in-memory database
|
||||||
|
repo = KeySnippetRepository(db=setup_db)
|
||||||
|
|
||||||
|
# Create a key snippet
|
||||||
|
filepath = "file_to_delete.py"
|
||||||
|
line_number = 30
|
||||||
|
snippet = "def function_to_delete():"
|
||||||
|
description = "Function to delete"
|
||||||
|
|
||||||
|
key_snippet = repo.create(
|
||||||
|
filepath=filepath,
|
||||||
|
line_number=line_number,
|
||||||
|
snippet=snippet,
|
||||||
|
description=description
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify the snippet exists
|
||||||
|
assert KeySnippet.get_or_none(KeySnippet.id == key_snippet.id) is not None
|
||||||
|
|
||||||
|
# Delete the snippet
|
||||||
|
delete_result = repo.delete(key_snippet.id)
|
||||||
|
|
||||||
|
# Verify the delete operation was successful
|
||||||
|
assert delete_result is True
|
||||||
|
|
||||||
|
# Verify the snippet no longer exists in the database
|
||||||
|
assert KeySnippet.get_or_none(KeySnippet.id == key_snippet.id) is None
|
||||||
|
|
||||||
|
# Try to delete a non-existent snippet
|
||||||
|
non_existent_delete = repo.delete(999)
|
||||||
|
assert non_existent_delete is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_all_key_snippets(setup_db):
|
||||||
|
"""Test retrieving all key snippets."""
|
||||||
|
# Set up repository with the in-memory database
|
||||||
|
repo = KeySnippetRepository(db=setup_db)
|
||||||
|
|
||||||
|
# Create some key snippets
|
||||||
|
snippets_data = [
|
||||||
|
{
|
||||||
|
"filepath": "file1.py",
|
||||||
|
"line_number": 10,
|
||||||
|
"snippet": "def function1():",
|
||||||
|
"description": "Function 1"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"filepath": "file2.py",
|
||||||
|
"line_number": 20,
|
||||||
|
"snippet": "def function2():",
|
||||||
|
"description": "Function 2"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"filepath": "file3.py",
|
||||||
|
"line_number": 30,
|
||||||
|
"snippet": "def function3():",
|
||||||
|
"description": "Function 3"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
for data in snippets_data:
|
||||||
|
repo.create(**data)
|
||||||
|
|
||||||
|
# Retrieve all snippets
|
||||||
|
all_snippets = repo.get_all()
|
||||||
|
|
||||||
|
# Verify we got the correct number of snippets
|
||||||
|
assert len(all_snippets) == len(snippets_data)
|
||||||
|
|
||||||
|
# Verify the content of each snippet
|
||||||
|
for i, snippet in enumerate(all_snippets):
|
||||||
|
assert snippet.filepath == snippets_data[i]["filepath"]
|
||||||
|
assert snippet.line_number == snippets_data[i]["line_number"]
|
||||||
|
assert snippet.snippet == snippets_data[i]["snippet"]
|
||||||
|
assert snippet.description == snippets_data[i]["description"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_snippets_dict(setup_db):
|
||||||
|
"""Test retrieving key snippets as a dictionary."""
|
||||||
|
# Set up repository with the in-memory database
|
||||||
|
repo = KeySnippetRepository(db=setup_db)
|
||||||
|
|
||||||
|
# Create some key snippets
|
||||||
|
snippets = []
|
||||||
|
snippets_data = [
|
||||||
|
{
|
||||||
|
"filepath": "file1.py",
|
||||||
|
"line_number": 10,
|
||||||
|
"snippet": "def function1():",
|
||||||
|
"description": "Function 1"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"filepath": "file2.py",
|
||||||
|
"line_number": 20,
|
||||||
|
"snippet": "def function2():",
|
||||||
|
"description": "Function 2"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"filepath": "file3.py",
|
||||||
|
"line_number": 30,
|
||||||
|
"snippet": "def function3():",
|
||||||
|
"description": "Function 3"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
for data in snippets_data:
|
||||||
|
snippets.append(repo.create(**data))
|
||||||
|
|
||||||
|
# Retrieve snippets as dictionary
|
||||||
|
snippets_dict = repo.get_snippets_dict()
|
||||||
|
|
||||||
|
# Verify we got the correct number of snippets
|
||||||
|
assert len(snippets_dict) == len(snippets_data)
|
||||||
|
|
||||||
|
# Verify each snippet is in the dictionary with the correct content
|
||||||
|
for i, snippet in enumerate(snippets):
|
||||||
|
assert snippet.id in snippets_dict
|
||||||
|
assert snippets_dict[snippet.id]["filepath"] == snippets_data[i]["filepath"]
|
||||||
|
assert snippets_dict[snippet.id]["line_number"] == snippets_data[i]["line_number"]
|
||||||
|
assert snippets_dict[snippet.id]["snippet"] == snippets_data[i]["snippet"]
|
||||||
|
assert snippets_dict[snippet.id]["description"] == snippets_data[i]["description"]
|
||||||
|
|
@ -54,14 +54,15 @@ def setup_test_model(cleanup_db):
|
||||||
# Initialize the database in memory
|
# Initialize the database in memory
|
||||||
db = init_db(in_memory=True)
|
db = init_db(in_memory=True)
|
||||||
|
|
||||||
|
# Initialize the database proxy
|
||||||
|
from ra_aid.database.models import initialize_database
|
||||||
|
initialize_database()
|
||||||
|
|
||||||
# Define a test model class
|
# Define a test model class
|
||||||
class TestModel(BaseModel):
|
class TestModel(BaseModel):
|
||||||
name = peewee.CharField(max_length=100)
|
name = peewee.CharField(max_length=100)
|
||||||
value = peewee.IntegerField(default=0)
|
value = peewee.IntegerField(default=0)
|
||||||
|
|
||||||
class Meta:
|
|
||||||
database = db
|
|
||||||
|
|
||||||
# Create the test table in a transaction
|
# Create the test table in a transaction
|
||||||
with db.atomic():
|
with db.atomic():
|
||||||
db.create_tables([TestModel], safe=True)
|
db.create_tables([TestModel], safe=True)
|
||||||
|
|
@ -79,14 +80,15 @@ def test_ensure_tables_created_with_models(cleanup_db, mock_logger):
|
||||||
# Initialize the database in memory
|
# Initialize the database in memory
|
||||||
db = init_db(in_memory=True)
|
db = init_db(in_memory=True)
|
||||||
|
|
||||||
# Define a test model that uses this database
|
# Initialize the database proxy
|
||||||
|
from ra_aid.database.models import initialize_database
|
||||||
|
initialize_database()
|
||||||
|
|
||||||
|
# Define a test model that uses the proxy database
|
||||||
class TestModel(BaseModel):
|
class TestModel(BaseModel):
|
||||||
name = peewee.CharField(max_length=100)
|
name = peewee.CharField(max_length=100)
|
||||||
value = peewee.IntegerField(default=0)
|
value = peewee.IntegerField(default=0)
|
||||||
|
|
||||||
class Meta:
|
|
||||||
database = db
|
|
||||||
|
|
||||||
# Call ensure_tables_created with explicit models
|
# Call ensure_tables_created with explicit models
|
||||||
ensure_tables_created([TestModel])
|
ensure_tables_created([TestModel])
|
||||||
|
|
||||||
|
|
@ -99,9 +101,9 @@ def test_ensure_tables_created_with_models(cleanup_db, mock_logger):
|
||||||
assert count == 1
|
assert count == 1
|
||||||
|
|
||||||
|
|
||||||
@patch("ra_aid.database.utils.get_db")
|
@patch("ra_aid.database.utils.initialize_database")
|
||||||
def test_ensure_tables_created_database_error(
|
def test_ensure_tables_created_database_error(
|
||||||
mock_get_db, setup_test_model, cleanup_db, mock_logger
|
mock_initialize_database, setup_test_model, cleanup_db, mock_logger
|
||||||
):
|
):
|
||||||
"""Test ensure_tables_created handles database errors."""
|
"""Test ensure_tables_created handles database errors."""
|
||||||
# Get the TestModel class from the fixture
|
# Get the TestModel class from the fixture
|
||||||
|
|
@ -113,8 +115,8 @@ def test_ensure_tables_created_database_error(
|
||||||
mock_db.atomic.return_value.__exit__.return_value = None
|
mock_db.atomic.return_value.__exit__.return_value = None
|
||||||
mock_db.create_tables.side_effect = peewee.DatabaseError("Test database error")
|
mock_db.create_tables.side_effect = peewee.DatabaseError("Test database error")
|
||||||
|
|
||||||
# Configure get_db to return our mock
|
# Configure initialize_database to return our mock
|
||||||
mock_get_db.return_value = mock_db
|
mock_initialize_database.return_value = mock_db
|
||||||
|
|
||||||
# Call ensure_tables_created and expect an exception
|
# Call ensure_tables_created and expect an exception
|
||||||
with pytest.raises(peewee.DatabaseError):
|
with pytest.raises(peewee.DatabaseError):
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue