key snippets db

This commit is contained in:
AI Christianson 2025-03-02 19:06:51 -05:00
parent be2eb298a5
commit 038e7b886c
17 changed files with 1320 additions and 718 deletions

View File

@ -610,6 +610,9 @@ def main():
memory=planning_memory,
config=config,
)
# Run cleanup tasks before exiting database context
run_cleanup()
except (KeyboardInterrupt, AgentInterrupt):
print()
@ -618,5 +621,17 @@ def main():
sys.exit(0)
def run_cleanup():
"""Run cleanup tasks after main execution."""
try:
# Import the key facts cleaner agent
from ra_aid.agents.key_facts_gc_agent import run_key_facts_gc_agent
# Run the key facts garbage collection agent regardless of the number of facts
run_key_facts_gc_agent()
except Exception as e:
logger.error(f"Failed to run cleanup tasks: {str(e)}")
if __name__ == "__main__":
main()

View File

@ -419,7 +419,7 @@ def run_research_agent(
project_info=formatted_project_info,
new_project_hints=NEW_PROJECT_HINTS if project_info.is_new else "",
)
config = _global_memory.get("config", {}) if not config else config
recursion_limit = config.get("recursion_limit", DEFAULT_RECURSION_LIMIT)
run_config = {

View File

@ -13,7 +13,7 @@ from ra_aid.database.migrations import (
get_migration_status,
init_migrations,
)
from ra_aid.database.models import BaseModel
from ra_aid.database.models import BaseModel, initialize_database
from ra_aid.database.utils import ensure_tables_created, get_model_count, truncate_table
__all__ = [
@ -22,6 +22,7 @@ __all__ = [
"close_db",
"DatabaseManager",
"BaseModel",
"initialize_database",
"get_model_count",
"truncate_table",
"ensure_tables_created",

View File

@ -14,6 +14,9 @@ import peewee
from ra_aid.logging_config import get_logger
# Import initialize_database after it's defined in models.py
# We need to do the import inside functions to avoid circular imports
# Create contextvar to hold the database connection
db_var = contextvars.ContextVar("db", default=None)
logger = get_logger(__name__)
@ -34,16 +37,23 @@ class DatabaseManager:
# Or with in-memory database:
with DatabaseManager(in_memory=True) as db:
# Use in-memory database
# Or with custom base directory:
with DatabaseManager(base_dir="/custom/path") as db:
# Use database in custom directory
"""
def __init__(self, in_memory: bool = False):
def __init__(self, in_memory: bool = False, base_dir: Optional[str] = None):
"""
Initialize the DatabaseManager.
Args:
in_memory: Whether to use an in-memory database (default: False)
base_dir: Optional base directory to use instead of current working directory.
If None, uses os.getcwd() (default: None)
"""
self.in_memory = in_memory
self.base_dir = base_dir
def __enter__(self) -> peewee.SqliteDatabase:
"""
@ -52,7 +62,19 @@ class DatabaseManager:
Returns:
peewee.SqliteDatabase: The initialized database connection
"""
return init_db(in_memory=self.in_memory)
db = init_db(in_memory=self.in_memory, base_dir=self.base_dir)
# Initialize the database proxy in models.py
try:
# Import here to avoid circular imports
from ra_aid.database.models import initialize_database
initialize_database()
except ImportError as e:
logger.error(f"Failed to import initialize_database: {str(e)}")
except Exception as e:
logger.error(f"Error initializing database proxy: {str(e)}")
return db
def __exit__(
self,
@ -74,7 +96,7 @@ class DatabaseManager:
return False
def init_db(in_memory: bool = False) -> peewee.SqliteDatabase:
def init_db(in_memory: bool = False, base_dir: Optional[str] = None) -> peewee.SqliteDatabase:
"""
Initialize the database connection.
@ -84,6 +106,8 @@ def init_db(in_memory: bool = False) -> peewee.SqliteDatabase:
Args:
in_memory: Whether to use an in-memory database (default: False)
base_dir: Optional base directory to use instead of current working directory.
If None, uses os.getcwd() (default: None)
Returns:
peewee.SqliteDatabase: The initialized database connection
@ -110,9 +134,9 @@ def init_db(in_memory: bool = False) -> peewee.SqliteDatabase:
db_path = ":memory:"
logger.debug("Using in-memory SQLite database")
else:
# Get current working directory and create .ra-aid directory if it doesn't exist
cwd = os.getcwd()
logger.debug(f"Current working directory: {cwd}")
# Get base directory (use current working directory if not provided)
cwd = base_dir if base_dir is not None else os.getcwd()
logger.debug(f"Base directory for database: {cwd}")
# Define the .ra-aid directory path
ra_aid_dir_str = os.path.join(cwd, ".ra-aid")
@ -300,13 +324,17 @@ def init_db(in_memory: bool = False) -> peewee.SqliteDatabase:
raise
def get_db() -> peewee.SqliteDatabase:
def get_db(base_dir: Optional[str] = None) -> peewee.SqliteDatabase:
"""
Get the current database connection.
If no connection exists, initializes a new one.
If connection exists but is closed, reopens it.
Args:
base_dir: Optional base directory to use instead of current working directory.
If None, uses os.getcwd() (default: None)
Returns:
peewee.SqliteDatabase: The current database connection
"""
@ -315,7 +343,7 @@ def get_db() -> peewee.SqliteDatabase:
if db is None:
# No database connection exists, initialize one
# Use the default in-memory mode (False)
return init_db(in_memory=False)
return init_db(in_memory=False, base_dir=base_dir)
# Check if connection is closed and reopen if needed
if db.is_closed():
@ -332,7 +360,7 @@ def get_db() -> peewee.SqliteDatabase:
in_memory = hasattr(db, "_is_in_memory") and db._is_in_memory
logger.debug(f"Creating new database connection (in_memory={in_memory})")
# Create a completely new database object, don't reuse the old one
return init_db(in_memory=in_memory)
return init_db(in_memory=in_memory, base_dir=base_dir)
return db

View File

@ -15,6 +15,21 @@ from ra_aid.logging_config import get_logger
T = TypeVar("T", bound="BaseModel")
logger = get_logger(__name__)
# Create a database proxy that will be initialized later
database_proxy = peewee.DatabaseProxy()
def initialize_database():
"""
Initialize the database proxy with a real database connection.
This function should be called before any database operations
to ensure the proxy points to a real database connection.
"""
db = get_db()
database_proxy.initialize(db)
return db
class BaseModel(peewee.Model):
"""
@ -28,7 +43,7 @@ class BaseModel(peewee.Model):
updated_at = peewee.DateTimeField(default=datetime.datetime.now)
class Meta:
database = get_db()
database = database_proxy
def save(self, *args: Any, **kwargs: Any) -> int:
"""
@ -75,4 +90,22 @@ class KeyFact(BaseModel):
# created_at and updated_at are inherited from BaseModel
class Meta:
table_name = "key_fact"
table_name = "key_fact"
class KeySnippet(BaseModel):
"""
Model representing a key code snippet stored in the database.
Key snippets are important code fragments from the project that need to be
referenced later. Each snippet includes its file location, line number,
the code content itself, and an optional description of its significance.
"""
filepath = peewee.TextField()
line_number = peewee.IntegerField()
snippet = peewee.TextField()
description = peewee.TextField(null=True)
# created_at and updated_at are inherited from BaseModel
class Meta:
table_name = "key_snippet"

View File

@ -10,7 +10,7 @@ from typing import Dict, List, Optional
import peewee
from ra_aid.database.connection import get_db
from ra_aid.database.models import KeyFact
from ra_aid.database.models import KeyFact, initialize_database
from ra_aid.logging_config import get_logger
logger = get_logger(__name__)
@ -29,6 +29,15 @@ class KeyFactRepository:
all_facts = repo.get_all()
"""
def __init__(self, db=None):
"""
Initialize the repository with an optional database connection.
Args:
db: Optional database connection to use. If None, will use initialize_database()
"""
self.db = db
def create(self, content: str) -> KeyFact:
"""
Create a new key fact in the database.
@ -43,7 +52,7 @@ class KeyFactRepository:
peewee.DatabaseError: If there's an error creating the fact
"""
try:
db = get_db()
db = self.db if self.db is not None else initialize_database()
fact = KeyFact.create(content=content)
logger.debug(f"Created key fact ID {fact.id}: {content}")
return fact
@ -65,7 +74,7 @@ class KeyFactRepository:
peewee.DatabaseError: If there's an error accessing the database
"""
try:
db = get_db()
db = self.db if self.db is not None else initialize_database()
return KeyFact.get_or_none(KeyFact.id == fact_id)
except peewee.DatabaseError as e:
logger.error(f"Failed to fetch key fact {fact_id}: {str(e)}")
@ -86,7 +95,7 @@ class KeyFactRepository:
peewee.DatabaseError: If there's an error updating the fact
"""
try:
db = get_db()
db = self.db if self.db is not None else initialize_database()
# First check if the fact exists
fact = self.get(fact_id)
if not fact:
@ -116,7 +125,7 @@ class KeyFactRepository:
peewee.DatabaseError: If there's an error deleting the fact
"""
try:
db = get_db()
db = self.db if self.db is not None else initialize_database()
# First check if the fact exists
fact = self.get(fact_id)
if not fact:
@ -142,7 +151,7 @@ class KeyFactRepository:
peewee.DatabaseError: If there's an error accessing the database
"""
try:
db = get_db()
db = self.db if self.db is not None else initialize_database()
return list(KeyFact.select().order_by(KeyFact.id))
except peewee.DatabaseError as e:
logger.error(f"Failed to fetch all key facts: {str(e)}")

View File

@ -0,0 +1,214 @@
"""
Key snippet repository implementation for database access.
This module provides a repository implementation for the KeySnippet model,
following the repository pattern for data access abstraction.
"""
from typing import Dict, List, Optional, Any
import peewee
from ra_aid.database.connection import get_db
from ra_aid.database.models import KeySnippet, initialize_database
from ra_aid.logging_config import get_logger
logger = get_logger(__name__)
class KeySnippetRepository:
"""
Repository for managing KeySnippet database operations.
This class provides methods for performing CRUD operations on the KeySnippet model,
abstracting the database access details from the business logic.
Example:
repo = KeySnippetRepository()
snippet = repo.create(
filepath="main.py",
line_number=42,
snippet="def hello_world():",
description="Main function definition"
)
all_snippets = repo.get_all()
"""
def __init__(self, db=None):
"""
Initialize the repository with an optional database connection.
Args:
db: Optional database connection to use. If None, will use initialize_database()
"""
self.db = db
def create(
self, filepath: str, line_number: int, snippet: str, description: Optional[str] = None
) -> KeySnippet:
"""
Create a new key snippet in the database.
Args:
filepath: Path to the source file
line_number: Line number where the snippet starts
snippet: The source code snippet text
description: Optional description of the significance
Returns:
KeySnippet: The newly created key snippet instance
Raises:
peewee.DatabaseError: If there's an error creating the snippet
"""
try:
db = self.db if self.db is not None else initialize_database()
key_snippet = KeySnippet.create(
filepath=filepath,
line_number=line_number,
snippet=snippet,
description=description
)
logger.debug(f"Created key snippet ID {key_snippet.id}: {filepath}:{line_number}")
return key_snippet
except peewee.DatabaseError as e:
logger.error(f"Failed to create key snippet: {str(e)}")
raise
def get(self, snippet_id: int) -> Optional[KeySnippet]:
"""
Retrieve a key snippet by its ID.
Args:
snippet_id: The ID of the key snippet to retrieve
Returns:
Optional[KeySnippet]: The key snippet instance if found, None otherwise
Raises:
peewee.DatabaseError: If there's an error accessing the database
"""
try:
db = self.db if self.db is not None else initialize_database()
return KeySnippet.get_or_none(KeySnippet.id == snippet_id)
except peewee.DatabaseError as e:
logger.error(f"Failed to fetch key snippet {snippet_id}: {str(e)}")
raise
def update(
self,
snippet_id: int,
filepath: str,
line_number: int,
snippet: str,
description: Optional[str] = None
) -> Optional[KeySnippet]:
"""
Update an existing key snippet.
Args:
snippet_id: The ID of the key snippet to update
filepath: Path to the source file
line_number: Line number where the snippet starts
snippet: The source code snippet text
description: Optional description of the significance
Returns:
Optional[KeySnippet]: The updated key snippet if found, None otherwise
Raises:
peewee.DatabaseError: If there's an error updating the snippet
"""
try:
db = self.db if self.db is not None else initialize_database()
# First check if the snippet exists
key_snippet = self.get(snippet_id)
if not key_snippet:
logger.warning(f"Attempted to update non-existent key snippet {snippet_id}")
return None
# Update the snippet
key_snippet.filepath = filepath
key_snippet.line_number = line_number
key_snippet.snippet = snippet
key_snippet.description = description
key_snippet.save()
logger.debug(f"Updated key snippet ID {snippet_id}: {filepath}:{line_number}")
return key_snippet
except peewee.DatabaseError as e:
logger.error(f"Failed to update key snippet {snippet_id}: {str(e)}")
raise
def delete(self, snippet_id: int) -> bool:
"""
Delete a key snippet by its ID.
Args:
snippet_id: The ID of the key snippet to delete
Returns:
bool: True if the snippet was deleted, False if it wasn't found
Raises:
peewee.DatabaseError: If there's an error deleting the snippet
"""
try:
db = self.db if self.db is not None else initialize_database()
# First check if the snippet exists
key_snippet = self.get(snippet_id)
if not key_snippet:
logger.warning(f"Attempted to delete non-existent key snippet {snippet_id}")
return False
# Delete the snippet
key_snippet.delete_instance()
logger.debug(f"Deleted key snippet ID {snippet_id}")
return True
except peewee.DatabaseError as e:
logger.error(f"Failed to delete key snippet {snippet_id}: {str(e)}")
raise
def get_all(self) -> List[KeySnippet]:
"""
Retrieve all key snippets from the database.
Returns:
List[KeySnippet]: List of all key snippet instances
Raises:
peewee.DatabaseError: If there's an error accessing the database
"""
try:
db = self.db if self.db is not None else initialize_database()
return list(KeySnippet.select().order_by(KeySnippet.id))
except peewee.DatabaseError as e:
logger.error(f"Failed to fetch all key snippets: {str(e)}")
raise
def get_snippets_dict(self) -> Dict[int, Dict[str, Any]]:
"""
Retrieve all key snippets as a dictionary mapping IDs to snippet information.
This method is useful for compatibility with the existing memory format.
Returns:
Dict[int, Dict[str, Any]]: Dictionary with snippet IDs as keys and
snippet information as values
Raises:
peewee.DatabaseError: If there's an error accessing the database
"""
try:
snippets = self.get_all()
return {
snippet.id: {
"filepath": snippet.filepath,
"line_number": snippet.line_number,
"snippet": snippet.snippet,
"description": snippet.description
}
for snippet in snippets
}
except peewee.DatabaseError as e:
logger.error(f"Failed to fetch key snippets as dictionary: {str(e)}")
raise

View File

@ -1,5 +1,8 @@
"""
Tests for the database connection module.
This file tests the database connection functionality using pytest's fixtures
for proper test isolation.
"""
import os
@ -21,642 +24,346 @@ from ra_aid.database.connection import (
@pytest.fixture
def cleanup_db():
"""
Fixture to clean up database connections after tests.
Fixture to clean up database connections between tests.
This ensures that we don't leak database connections between tests
and that the db_var contextvar is reset.
"""
# Run the test
yield
# Clean up after the test
db = db_var.get()
if db is not None:
# Clean up attributes we may have added
if hasattr(db, "_is_in_memory"):
delattr(db, "_is_in_memory")
if hasattr(db, "_message_shown"):
delattr(db, "_message_shown")
# Close the connection if it's open
if not db.is_closed():
db.close()
# Reset the contextvar
db_var.set(None)
@pytest.fixture
def setup_in_memory_db():
def db_path_mock(tmp_path, monkeypatch):
"""
Fixture to set up an in-memory database for testing.
Fixture to mock os.getcwd() to return a temporary directory path.
This ensures that all database operations use the temporary directory
and never touch the actual current working directory.
"""
# Initialize in-memory database
db = init_db(in_memory=True)
# Run the test
yield db
# Clean up
if not db.is_closed():
db.close()
db_var.set(None)
def test_init_db_creates_directory(cleanup_db, tmp_path):
"""
Test that init_db creates the .ra-aid directory if it doesn't exist.
"""
# Get and print the original working directory
original_cwd = os.getcwd()
print(f"Original working directory: {original_cwd}")
# Convert tmp_path to string for consistent handling
tmp_path_str = str(tmp_path.absolute())
print(f"Temporary directory path: {tmp_path_str}")
# Change to the temporary directory
os.chdir(tmp_path_str)
current_cwd = os.getcwd()
print(f"Changed working directory to: {current_cwd}")
assert (
current_cwd == tmp_path_str
), f"Failed to change directory: {current_cwd} != {tmp_path_str}"
# Create the .ra-aid directory manually to ensure it exists
ra_aid_path_str = os.path.join(current_cwd, ".ra-aid")
print(f"Creating .ra-aid directory at: {ra_aid_path_str}")
os.makedirs(ra_aid_path_str, exist_ok=True)
# Verify the directory was created
assert os.path.exists(
ra_aid_path_str
), f".ra-aid directory not found at {ra_aid_path_str}"
assert os.path.isdir(
ra_aid_path_str
), f"{ra_aid_path_str} exists but is not a directory"
# Create a test file to verify write permissions
test_file_path = os.path.join(ra_aid_path_str, "test_write.txt")
print(f"Creating test file to verify write permissions: {test_file_path}")
with open(test_file_path, "w") as f:
f.write("Test write permissions")
# Verify the test file was created
assert os.path.exists(test_file_path), f"Test file not created at {test_file_path}"
# Create an empty database file to ensure it exists before init_db
db_file_str = os.path.join(ra_aid_path_str, "pk.db")
print(f"Creating empty database file at: {db_file_str}")
with open(db_file_str, "w") as f:
f.write("") # Create empty file
# Verify the database file was created
assert os.path.exists(
db_file_str
), f"Empty database file not created at {db_file_str}"
print(f"Empty database file size: {os.path.getsize(db_file_str)} bytes")
# Get directory permissions for debugging
dir_perms = oct(os.stat(ra_aid_path_str).st_mode)[-3:]
print(f"Directory permissions: {dir_perms}")
# Initialize the database
print("Calling init_db()")
db = init_db()
print("init_db() returned successfully")
# List contents of the current directory for debugging
print(f"Contents of current directory: {os.listdir(current_cwd)}")
# List contents of the .ra-aid directory for debugging
print(f"Contents of .ra-aid directory: {os.listdir(ra_aid_path_str)}")
# Check that the database file exists using os.path
assert os.path.exists(db_file_str), f"Database file not found at {db_file_str}"
assert os.path.isfile(db_file_str), f"{db_file_str} exists but is not a file"
print(f"Final database file size: {os.path.getsize(db_file_str)} bytes")
def test_init_db_creates_database_file(cleanup_db, tmp_path):
"""
Test that init_db creates the database file.
"""
# Change to the temporary directory
os.chdir(tmp_path)
# Initialize the database
init_db()
# Check that the database file was created
assert (tmp_path / ".ra-aid" / "pk.db").exists()
assert (tmp_path / ".ra-aid" / "pk.db").is_file()
def test_init_db_returns_database_connection(cleanup_db):
"""
Test that init_db returns a database connection.
"""
# Initialize the database
db = init_db()
# Check that the database connection is returned
assert isinstance(db, peewee.SqliteDatabase)
assert not db.is_closed()
def test_init_db_with_in_memory_mode(cleanup_db):
"""
Test that init_db with in_memory=True creates an in-memory database.
"""
# Initialize the database in in-memory mode
db = init_db(in_memory=True)
# Check that the database connection is returned
assert isinstance(db, peewee.SqliteDatabase)
assert not db.is_closed()
assert hasattr(db, "_is_in_memory")
assert db._is_in_memory is True
def test_in_memory_mode_no_directory_created(cleanup_db, tmp_path):
"""
Test that when using in-memory mode, no directory is created.
"""
# Change to the temporary directory
os.chdir(tmp_path)
# Initialize the database in in-memory mode
init_db(in_memory=True)
# Check that the .ra-aid directory was not created
# (Note: it might be created by other tests, so we can't assert it doesn't exist)
# Instead, check that the database file was not created
assert not (tmp_path / ".ra-aid" / "pk.db").exists()
def test_init_db_returns_existing_connection(cleanup_db):
"""
Test that init_db returns the existing connection if one exists.
"""
# Initialize the database
db1 = init_db()
# Initialize the database again
db2 = init_db()
# Check that the same connection is returned
assert db1 is db2
def test_init_db_reopens_closed_connection(cleanup_db):
"""
Test that init_db reopens a closed connection.
"""
# Initialize the database
db1 = init_db()
# Close the connection
db1.close()
# Initialize the database again
db2 = init_db()
# Check that the same connection is returned and it's open
assert db1 is db2
assert not db1.is_closed()
def test_get_db_initializes_connection(cleanup_db):
"""
Test that get_db initializes a connection if none exists.
"""
# Get the database connection
db = get_db()
# Check that a connection was initialized
assert isinstance(db, peewee.SqliteDatabase)
assert not db.is_closed()
def test_get_db_returns_existing_connection(cleanup_db):
"""
Test that get_db returns the existing connection if one exists.
"""
# Initialize the database
db1 = init_db()
# Get the database connection
db2 = get_db()
# Check that the same connection is returned
assert db1 is db2
def test_get_db_reopens_closed_connection(cleanup_db):
"""
Test that get_db reopens a closed connection.
"""
# Initialize the database
db = init_db()
# Close the connection
db.close()
# Get the database connection
db2 = get_db()
# Check that the same connection is returned and it's open
assert db is db2
assert not db.is_closed()
def test_get_db_handles_reopen_error(cleanup_db, monkeypatch):
"""
Test that get_db handles errors when reopening a connection.
"""
# Initialize the database
db = init_db()
# Close the connection
db.close()
# Create a patched version of the connect method that raises an error
original_connect = peewee.SqliteDatabase.connect
def mock_connect(self, *args, **kwargs):
if self is db: # Only raise for the specific db instance
raise peewee.OperationalError("Test error")
return original_connect(self, *args, **kwargs)
# Apply the patch
monkeypatch.setattr(peewee.SqliteDatabase, "connect", mock_connect)
# Get the database connection
db2 = get_db()
# Check that a new connection was initialized
assert db is not db2
assert not db2.is_closed()
def test_close_db_closes_connection(cleanup_db):
"""
Test that close_db closes the connection.
"""
# Initialize the database
db = init_db()
# Close the connection
close_db()
# Check that the connection is closed
assert db.is_closed()
def test_close_db_handles_no_connection():
"""
Test that close_db handles the case where no connection exists.
"""
# Reset the contextvar
db_var.set(None)
# Close the connection (should not raise an error)
close_db()
def test_close_db_handles_already_closed_connection(cleanup_db):
"""
Test that close_db handles the case where the connection is already closed.
"""
# Initialize the database
db = init_db()
# Close the connection
db.close()
# Close the connection again (should not raise an error)
close_db()
@patch("ra_aid.database.connection.peewee.SqliteDatabase.close")
def test_close_db_handles_error(mock_close, cleanup_db):
"""
Test that close_db handles errors when closing the connection.
"""
# Initialize the database
init_db()
# Make close raise an error
mock_close.side_effect = peewee.DatabaseError("Test error")
# Close the connection (should not raise an error)
close_db()
def test_database_manager_context_manager(cleanup_db):
"""
Test that DatabaseManager works as a context manager.
"""
# Use the context manager
with DatabaseManager() as db:
# Check that a connection was initialized
assert isinstance(db, peewee.SqliteDatabase)
assert not db.is_closed()
# Store the connection for later
db_in_context = db
# Check that the connection is closed after exiting the context
assert db_in_context.is_closed()
def test_database_manager_with_in_memory_mode(cleanup_db):
"""
Test that DatabaseManager with in_memory=True creates an in-memory database.
"""
# Use the context manager with in_memory=True
with DatabaseManager(in_memory=True) as db:
# Check that a connection was initialized
assert isinstance(db, peewee.SqliteDatabase)
assert not db.is_closed()
assert hasattr(db, "_is_in_memory")
assert db._is_in_memory is True
def test_init_db_shows_message_only_once(cleanup_db, caplog):
"""
Test that init_db only shows the initialization message once.
"""
# Initialize the database
init_db(in_memory=True)
# Clear the log
caplog.clear()
# Initialize the database again
init_db(in_memory=True)
# Check that no message was logged
assert "database connection initialized" not in caplog.text.lower()
def test_init_db_sets_is_in_memory_attribute(cleanup_db):
"""
Test that init_db sets the _is_in_memory attribute.
"""
# Initialize the database with in_memory=False
db = init_db(in_memory=False)
# Check that the _is_in_memory attribute is set to False
assert hasattr(db, "_is_in_memory")
assert db._is_in_memory is False
# Reset the contextvar
db_var.set(None)
# Initialize the database with in_memory=True
db = init_db(in_memory=True)
# Check that the _is_in_memory attribute is set to True
assert hasattr(db, "_is_in_memory")
assert db._is_in_memory is True
"""
Tests for the database connection module.
"""
import pytest
@pytest.fixture
def cleanup_db():
"""
Fixture to clean up database connections and files between tests.
This fixture:
1. Closes any open database connection
2. Resets the contextvar
3. Cleans up the .ra-aid directory
"""
# Store the current working directory
original_cwd = os.getcwd()
# Run the test
yield
# Clean up after the test
try:
# Close any open database connection
close_db()
# Reset the contextvar
db_var.set(None)
# Clean up the .ra-aid directory if it exists
ra_aid_dir = Path(os.getcwd()) / ".ra-aid"
ra_aid_dir_str = str(ra_aid_dir.absolute())
# Check using both methods
path_exists = ra_aid_dir.exists()
os_exists = os.path.exists(ra_aid_dir_str)
print(f"Cleanup check: Path.exists={path_exists}, os.path.exists={os_exists}")
if os_exists:
# Only remove the database file, not the entire directory
db_file = os.path.join(ra_aid_dir_str, "pk.db")
if os.path.exists(db_file):
os.unlink(db_file)
# Remove WAL and SHM files if they exist
wal_file = os.path.join(ra_aid_dir_str, "pk.db-wal")
if os.path.exists(wal_file):
os.unlink(wal_file)
shm_file = os.path.join(ra_aid_dir_str, "pk.db-shm")
if os.path.exists(shm_file):
os.unlink(shm_file)
# List remaining contents for debugging
if os.path.exists(ra_aid_dir_str):
print(f"Directory contents after cleanup: {os.listdir(ra_aid_dir_str)}")
except Exception as e:
# Log but don't fail if cleanup has issues
print(f"Cleanup error (non-fatal): {str(e)}")
# Make sure we're back in the original directory
# Create the .ra-aid directory in the temporary path
ra_aid_dir = tmp_path / ".ra-aid"
ra_aid_dir.mkdir(exist_ok=True)
# Mock os.getcwd() to return the temporary directory
monkeypatch.setattr(os, "getcwd", lambda: tmp_path_str)
yield tmp_path
# Ensure we're back to the original directory after the test
os.chdir(original_cwd)
class TestInitDb:
"""Tests for the init_db function."""
def test_init_db_default(self, cleanup_db):
"""Test init_db with default parameters."""
# Get the absolute path of the current working directory
cwd = os.getcwd()
print(f"Current working directory: {cwd}")
# Initialize the database
db = init_db()
assert isinstance(db, peewee.SqliteDatabase)
assert not db.is_closed()
assert hasattr(db, "_is_in_memory")
assert db._is_in_memory is False
# Verify the database file was created using both Path and os.path methods
ra_aid_dir = Path(cwd) / ".ra-aid"
ra_aid_dir_str = str(ra_aid_dir.absolute())
# Check directory existence using both methods
path_exists = ra_aid_dir.exists()
os_exists = os.path.exists(ra_aid_dir_str)
print(f"Directory check: Path.exists={path_exists}, os.path.exists={os_exists}")
# List the contents of the current directory
print(f"Contents of {cwd}: {os.listdir(cwd)}")
# If the directory exists, list its contents
if os_exists:
print(f"Contents of {ra_aid_dir_str}: {os.listdir(ra_aid_dir_str)}")
# Use os.path for assertions to be more reliable
assert os.path.exists(
ra_aid_dir_str
), f"Directory {ra_aid_dir_str} does not exist"
assert os.path.isdir(ra_aid_dir_str), f"{ra_aid_dir_str} is not a directory"
db_file = os.path.join(ra_aid_dir_str, "pk.db")
assert os.path.exists(db_file), f"Database file {db_file} does not exist"
assert os.path.isfile(db_file), f"{db_file} is not a file"
def test_init_db_in_memory(self, cleanup_db):
"""Test init_db with in_memory=True."""
# Reset the contextvar to ensure a fresh start
db_var.set(None)
db = init_db(in_memory=True)
assert isinstance(db, peewee.SqliteDatabase)
assert not db.is_closed()
assert hasattr(db, "_is_in_memory")
assert db._is_in_memory is True
def test_init_db_reuses_connection(self, cleanup_db):
"""Test that init_db reuses an existing connection."""
db1 = init_db()
db2 = init_db()
assert db1 is db2
def test_init_db_reopens_closed_connection(self, cleanup_db):
"""Test that init_db reopens a closed connection."""
db1 = init_db()
db1.close()
assert db1.is_closed()
db2 = init_db()
assert db1 is db2
assert not db1.is_closed()
class TestGetDb:
"""Tests for the get_db function."""
def test_get_db_creates_connection(self, cleanup_db):
"""Test that get_db creates a new connection if none exists."""
# Reset the contextvar to ensure no connection exists
db_var.set(None)
db = get_db()
def test_init_db_creates_directory(self, cleanup_db, db_path_mock):
"""Test that init_db creates the .ra-aid directory if it doesn't exist."""
# Remove the .ra-aid directory to test creation
ra_aid_dir = db_path_mock / ".ra-aid"
if ra_aid_dir.exists():
for item in ra_aid_dir.iterdir():
if item.is_file():
item.unlink()
ra_aid_dir.rmdir()
# Initialize the database
db = init_db()
# Check that the directory was created
assert ra_aid_dir.exists()
assert ra_aid_dir.is_dir()
assert isinstance(db, peewee.SqliteDatabase)
assert not db.is_closed()
assert hasattr(db, "_is_in_memory")
assert db._is_in_memory is False
def test_get_db_reuses_connection(self, cleanup_db):
"""Test that get_db reuses an existing connection."""
db1 = init_db()
db2 = get_db()
def test_init_db_creates_database_file(self, cleanup_db, db_path_mock):
"""Test that init_db creates the database file."""
# Initialize the database
init_db()
# Check that the database file was created
assert (db_path_mock / ".ra-aid" / "pk.db").exists()
assert (db_path_mock / ".ra-aid" / "pk.db").is_file()
def test_init_db_reuses_connection(self, cleanup_db):
"""Test that init_db reuses an existing connection."""
# Reset the contextvar to ensure a fresh start
db_var.set(None)
# Use in_memory=True for this test to avoid touching the filesystem
db1 = init_db(in_memory=True)
db2 = init_db(in_memory=True)
assert db1 is db2
def test_get_db_reopens_closed_connection(self, cleanup_db):
"""Test that get_db reopens a closed connection."""
db1 = init_db()
def test_init_db_reopens_closed_connection(self, cleanup_db):
"""Test that init_db reopens a closed connection."""
# Reset the contextvar to ensure a fresh start
db_var.set(None)
# Use in_memory=True for this test to avoid touching the filesystem
db1 = init_db(in_memory=True)
db1.close()
assert db1.is_closed()
db2 = init_db(in_memory=True)
assert db1 is db2
assert not db1.is_closed()
def test_in_memory_mode_no_directory_created(self, cleanup_db, db_path_mock):
"""Test that when using in_memory mode, no database file is created."""
# Initialize the database in in-memory mode
init_db(in_memory=True)
# Check that the database file was not created
assert not (db_path_mock / ".ra-aid" / "pk.db").exists()
def test_init_db_sets_is_in_memory_attribute(self, cleanup_db):
"""Test that init_db sets the _is_in_memory attribute."""
# Test with in_memory=True
db = init_db(in_memory=True)
assert hasattr(db, "_is_in_memory")
assert db._is_in_memory is True
# Reset the contextvar
db_var.set(None)
# Test with in_memory=False, but use a mocked directory
with patch("os.getcwd") as mock_getcwd:
temp_dir = Path("/tmp/testdb")
mock_getcwd.return_value = str(temp_dir)
# Mock os.path.exists and os.makedirs to avoid filesystem operations
with patch("os.path.exists", return_value=True):
with patch("os.makedirs"):
with patch("os.path.isdir", return_value=True):
with patch.object(peewee.SqliteDatabase, "connect"):
with patch.object(peewee.SqliteDatabase, "execute_sql"):
db = init_db(in_memory=False)
assert hasattr(db, "_is_in_memory")
assert db._is_in_memory is False
class TestGetDb:
"""Tests for the get_db function."""
def test_get_db_initializes_connection(self, cleanup_db):
"""Test that get_db initializes a connection if none exists."""
# Reset the contextvar to ensure no connection exists
db_var.set(None)
# Use a patch to avoid touching the filesystem
with patch("ra_aid.database.connection.init_db") as mock_init_db:
mock_db = peewee.SqliteDatabase(":memory:")
mock_db._is_in_memory = False
mock_init_db.return_value = mock_db
db = get_db()
mock_init_db.assert_called_once_with(in_memory=False, base_dir=None)
assert db is mock_db
def test_get_db_returns_existing_connection(self, cleanup_db):
"""Test that get_db returns the existing connection if one exists."""
# Reset the contextvar to ensure a fresh start
db_var.set(None)
# Use in_memory=True for this test to avoid touching the filesystem
db1 = init_db(in_memory=True)
db2 = get_db()
assert db1 is db2
def test_get_db_reopens_closed_connection(self, cleanup_db):
"""Test that get_db reopens a closed connection."""
# Reset the contextvar to ensure a fresh start
db_var.set(None)
# Use in_memory=True for this test to avoid touching the filesystem
db1 = init_db(in_memory=True)
db1.close()
assert db1.is_closed()
db2 = get_db()
assert db1 is db2
assert not db1.is_closed()
def test_get_db_handles_reopen_error(self, cleanup_db, monkeypatch):
"""Test that get_db handles errors when reopening a connection."""
# Reset the contextvar to ensure a fresh start
db_var.set(None)
# Use in_memory=True for this test to avoid touching the filesystem
db = init_db(in_memory=True)
# Close the connection
db.close()
# Create a patched version of the connect method that raises an error
original_connect = peewee.SqliteDatabase.connect
def mock_connect(self, *args, **kwargs):
if self is db: # Only raise for the specific db instance
raise peewee.OperationalError("Test error")
return original_connect(self, *args, **kwargs)
# Apply the patch
monkeypatch.setattr(peewee.SqliteDatabase, "connect", mock_connect)
# Get the database connection - this should create a new one
db2 = get_db()
# Check that a new connection was initialized
assert db is not db2
assert not db2.is_closed()
assert hasattr(db2, "_is_in_memory")
assert db2._is_in_memory is True # Should preserve the in_memory setting
class TestCloseDb:
"""Tests for the close_db function."""
def test_close_db(self, cleanup_db):
"""Test that close_db closes an open connection."""
db = init_db()
assert not db.is_closed()
def test_close_db_closes_connection(self, cleanup_db):
"""Test that close_db closes the connection."""
# Use in_memory=True for this test to avoid touching the filesystem
db = init_db(in_memory=True)
# Close the connection
close_db()
# Check that the connection is closed
assert db.is_closed()
def test_close_db_no_connection(self, cleanup_db):
def test_close_db_handles_no_connection(self):
"""Test that close_db handles the case where no connection exists."""
# Reset the contextvar to ensure no connection exists
# Reset the contextvar
db_var.set(None)
# This should not raise an exception
# Close the connection (should not raise an error)
close_db()
def test_close_db_already_closed(self, cleanup_db):
def test_close_db_handles_already_closed_connection(self, cleanup_db):
"""Test that close_db handles the case where the connection is already closed."""
db = init_db()
# Use in_memory=True for this test to avoid touching the filesystem
db = init_db(in_memory=True)
# Close the connection
db.close()
assert db.is_closed()
# This should not raise an exception
# Close the connection again (should not raise an error)
close_db()
@patch("ra_aid.database.connection.peewee.SqliteDatabase.close")
def test_close_db_handles_error(self, mock_close, cleanup_db):
"""Test that close_db handles errors when closing the connection."""
# Use in_memory=True for this test to avoid touching the filesystem
init_db(in_memory=True)
# Make close raise an error
mock_close.side_effect = peewee.DatabaseError("Test error")
# Close the connection (should not raise an error)
close_db()
class TestDatabaseManager:
"""Tests for the DatabaseManager class."""
def test_database_manager_default(self, cleanup_db):
"""Test DatabaseManager with default parameters."""
with DatabaseManager() as db:
assert isinstance(db, peewee.SqliteDatabase)
assert not db.is_closed()
assert hasattr(db, "_is_in_memory")
assert db._is_in_memory is False
# Verify the database file was created
ra_aid_dir = Path(os.getcwd()) / ".ra-aid"
assert ra_aid_dir.exists()
assert (ra_aid_dir / "pk.db").exists()
# Verify the connection is closed after exiting the context
assert db.is_closed()
def test_database_manager_in_memory(self, cleanup_db):
"""Test DatabaseManager with in_memory=True."""
def test_database_manager_context_manager_in_memory(self, cleanup_db):
"""Test that DatabaseManager works as a context manager with in_memory=True."""
# Use in_memory=True for this test to avoid touching the filesystem
with DatabaseManager(in_memory=True) as db:
# Check that a connection was initialized
assert isinstance(db, peewee.SqliteDatabase)
assert not db.is_closed()
assert hasattr(db, "_is_in_memory")
assert db._is_in_memory is True
# Verify the connection is closed after exiting the context
assert db.is_closed()
# Store the connection for later
db_in_context = db
# Check that the connection is closed after exiting the context
assert db_in_context.is_closed()
def test_database_manager_context_manager_physical_file(self, cleanup_db, db_path_mock):
"""Test that DatabaseManager works as a context manager with a physical file."""
with DatabaseManager(in_memory=False) as db:
# Check that a connection was initialized
assert isinstance(db, peewee.SqliteDatabase)
assert not db.is_closed()
assert hasattr(db, "_is_in_memory")
assert db._is_in_memory is False
# Check that the database file was created
assert (db_path_mock / ".ra-aid" / "pk.db").exists()
assert (db_path_mock / ".ra-aid" / "pk.db").is_file()
# Store the connection for later
db_in_context = db
# Check that the connection is closed after exiting the context
assert db_in_context.is_closed()
def test_database_manager_exception_handling(self, cleanup_db):
"""Test that DatabaseManager properly handles exceptions."""
# Use in_memory=True for this test to avoid touching the filesystem
try:
with DatabaseManager() as db:
with DatabaseManager(in_memory=True) as db:
assert not db.is_closed()
raise ValueError("Test exception")
except ValueError:
# The exception should be propagated
pass
# Verify the connection is closed even if an exception occurred
assert db.is_closed()
def test_init_db_shows_message_only_once(cleanup_db, caplog):
"""Test that init_db only shows the initialization message once."""
# Reset the contextvar to ensure a fresh start
db_var.set(None)
# Use in_memory=True for this test to avoid touching the filesystem
init_db(in_memory=True)
# Clear the log
caplog.clear()
# Initialize the database again
init_db(in_memory=True)
# Check that no message was logged
assert "database connection initialized" not in caplog.text.lower()

View File

@ -10,7 +10,7 @@ from typing import List, Type
import peewee
from ra_aid.database.connection import get_db
from ra_aid.database.models import BaseModel
from ra_aid.database.models import BaseModel, initialize_database
from ra_aid.logging_config import get_logger
logger = get_logger(__name__)
@ -26,7 +26,7 @@ def ensure_tables_created(models: List[Type[BaseModel]] = None) -> None:
Args:
models: Optional list of model classes to create tables for
"""
db = get_db()
db = initialize_database()
if models is None:
# If no models are specified, try to discover them
@ -88,7 +88,7 @@ def truncate_table(model_class: Type[BaseModel]) -> None:
Args:
model_class: The model class to truncate
"""
db = get_db()
db = initialize_database()
try:
with db.atomic():
model_class.delete().execute()

View File

@ -0,0 +1,57 @@
"""Peewee migrations -- 003_20250302_163752_add_key_snippet_model.py.
Some examples (model - class or model name)::
> Model = migrator.orm['table_name'] # Return model in current state by name
> Model = migrator.ModelClass # Return model in current state by name
> migrator.sql(sql) # Run custom SQL
> migrator.run(func, *args, **kwargs) # Run python function with the given args
> migrator.create_model(Model) # Create a model (could be used as decorator)
> migrator.remove_model(model, cascade=True) # Remove a model
> migrator.add_fields(model, **fields) # Add fields to a model
> migrator.change_fields(model, **fields) # Change fields
> migrator.remove_fields(model, *field_names, cascade=True)
> migrator.rename_field(model, old_field_name, new_field_name)
> migrator.rename_table(model, new_table_name)
> migrator.add_index(model, *col_names, unique=False)
> migrator.add_not_null(model, *field_names)
> migrator.add_default(model, field_name, default)
> migrator.add_constraint(model, name, sql)
> migrator.drop_index(model, *col_names)
> migrator.drop_not_null(model, *field_names)
> migrator.drop_constraints(model, *constraints)
"""
from contextlib import suppress
import peewee as pw
from peewee_migrate import Migrator
with suppress(ImportError):
import playhouse.postgres_ext as pw_pext
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your migrations here."""
@migrator.create_model
class KeySnippet(pw.Model):
id = pw.AutoField()
created_at = pw.DateTimeField()
updated_at = pw.DateTimeField()
filepath = pw.TextField()
line_number = pw.IntegerField()
snippet = pw.TextField()
description = pw.TextField(null=True)
class Meta:
table_name = "key_snippet"
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here."""
migrator.remove_model('key_snippet')

View File

@ -0,0 +1,86 @@
"""
Key snippets model formatter module.
This module provides utility functions for formatting key snippets from database models
into consistent markdown styling for display or output purposes.
"""
from typing import Dict, Optional
def format_key_snippet(snippet_id: int, filepath: str, line_number: int, snippet: str, description: Optional[str] = None) -> str:
"""
Format a single key snippet with markdown formatting.
Args:
snippet_id: The identifier of the snippet
filepath: Path to the source file
line_number: Line number where the snippet starts
snippet: The source code snippet text
description: Optional description of the significance
Returns:
str: Formatted key snippet as markdown
Example:
>>> format_key_snippet(1, "src/main.py", 10, "def hello():\\n return 'world'", "Main function")
'## 📝 Code Snippet #1\\n\\n**Source Location**:\\n- File: `src/main.py`\\n- Line: `10`\\n\\n**Code**:\\n```python\\ndef hello():\\n return \\'world\\'\\n```\\n\\n**Description**:\\nMain function'
"""
if not snippet:
return ""
formatted_snippet = f"## 📝 Code Snippet #{snippet_id}\n\n"
formatted_snippet += f"**Source Location**:\n"
formatted_snippet += f"- File: `{filepath}`\n"
formatted_snippet += f"- Line: `{line_number}`\n\n"
formatted_snippet += f"**Code**:\n```python\n{snippet}\n```\n"
if description:
formatted_snippet += f"\n**Description**:\n{description}"
return formatted_snippet
def format_key_snippets_dict(snippets_dict: Dict[int, Dict]) -> str:
"""
Format a dictionary of key snippets with consistent markdown formatting.
Args:
snippets_dict: Dictionary mapping snippet IDs to snippet information dictionaries.
Each snippet dictionary should contain: filepath, line_number, snippet,
and optionally description.
Returns:
str: Formatted key snippets as markdown with proper spacing and headings
Example:
>>> snippets = {
... 1: {
... "filepath": "src/main.py",
... "line_number": 10,
... "snippet": "def hello():\\n return 'world'",
... "description": "Main function"
... }
... }
>>> format_key_snippets_dict(snippets)
'## 📝 Code Snippet #1\\n\\n**Source Location**:\\n- File: `src/main.py`\\n- Line: `10`\\n\\n**Code**:\\n```python\\ndef hello():\\n return \\'world\\'\\n```\\n\\n**Description**:\\nMain function'
"""
if not snippets_dict:
return ""
# Sort by ID for consistent output and format as markdown sections
snippets = []
for snippet_id, snippet_info in sorted(snippets_dict.items()):
snippets.extend([
format_key_snippet(
snippet_id,
snippet_info.get("filepath", ""),
snippet_info.get("line_number", 0),
snippet_info.get("snippet", ""),
snippet_info.get("description", None)
),
"" # Empty line between snippets
])
# Join all snippets and remove trailing newline
return "\n".join(snippets).rstrip()

View File

@ -18,6 +18,8 @@ from ra_aid.agent_context import (
mark_task_completed,
)
from ra_aid.database.repositories.key_fact_repository import KeyFactRepository
from ra_aid.database.repositories.key_snippet_repository import KeySnippetRepository
from ra_aid.model_formatters import key_snippets_formatter
from ra_aid.logging_config import get_logger
logger = get_logger(__name__)
@ -40,6 +42,9 @@ console = Console()
# Initialize repository for key facts
key_fact_repository = KeyFactRepository()
# Initialize repository for key snippets
key_snippet_repository = KeySnippetRepository()
# Global memory store
_global_memory: Dict[str, Any] = {
"research_notes": [],
@ -204,11 +209,19 @@ def emit_key_snippet(snippet_info: SnippetInfo) -> str:
# Add filepath to related files
emit_related_files.invoke({"files": [snippet_info["filepath"]]})
# Get and increment snippet ID
snippet_id = _global_memory["key_snippet_id_counter"]
_global_memory["key_snippet_id_counter"] += 1
# Store snippet info
# Create a new key snippet in the database
key_snippet = key_snippet_repository.create(
filepath=snippet_info["filepath"],
line_number=snippet_info["line_number"],
snippet=snippet_info["snippet"],
description=snippet_info["description"],
)
# For backward compatibility, also store in global memory
if "key_snippets" not in _global_memory:
_global_memory["key_snippets"] = {}
snippet_id = key_snippet.id
_global_memory["key_snippets"][snippet_id] = snippet_info
# Format display text as markdown
@ -248,16 +261,27 @@ def delete_key_snippets(snippet_ids: List[int]) -> str:
"""
results = []
for snippet_id in snippet_ids:
# Try to delete from database first
success = key_snippet_repository.delete(snippet_id)
# For backward compatibility, also delete from global memory
if snippet_id in _global_memory["key_snippets"]:
# Delete the snippet
deleted_snippet = _global_memory["key_snippets"].pop(snippet_id)
success_msg = f"Successfully deleted snippet #{snippet_id} from {deleted_snippet['filepath']}"
console.print(
Panel(
Markdown(success_msg), title="Snippet Deleted", border_style="green"
)
filepath = deleted_snippet['filepath']
else:
# If not in memory but successful database delete, use generic message
if success:
filepath = "database"
else:
continue # Skip if not found in either place
success_msg = f"Successfully deleted snippet #{snippet_id} from {filepath}"
console.print(
Panel(
Markdown(success_msg), title="Snippet Deleted", border_style="green"
)
results.append(success_msg)
)
results.append(success_msg)
log_work_event(f"Deleted snippets {snippet_ids}.")
return "Snippets deleted."
@ -580,8 +604,8 @@ def get_memory_value(key: str) -> str:
"""
Get a value from global memory.
Note: Key facts are now handled by KeyFactRepository and the key_facts_formatter module,
not through this function.
Note: Key facts and key snippets are now handled by their respective repositories
and formatter modules, but this function maintains backward compatibility.
Different memory types return different formats:
- key_snippets: Returns formatted snippets with file path, line number and content
@ -596,29 +620,62 @@ def get_memory_value(key: str) -> str:
- For other types: One value per line
"""
if key == "key_snippets":
values = _global_memory.get(key, {})
if not values:
return ""
# Format each snippet with file info and content using markdown
snippets = []
for k, v in sorted(values.items()):
snippet_text = [
f"## 📝 Code Snippet #{k}",
"", # Empty line for better markdown spacing
"**Source Location**:",
f"- File: `{v['filepath']}`",
f"- Line: `{v['line_number']}`",
"", # Empty line before code block
"**Code**:",
"```python",
v["snippet"].rstrip(), # Remove trailing whitespace
"```",
]
if v["description"]:
# Add empty line and description
snippet_text.extend(["", "**Description**:", v["description"]])
snippets.append("\n".join(snippet_text))
return "\n\n".join(snippets)
try:
# Try to get snippets from repository first
snippets_dict = key_snippet_repository.get_snippets_dict()
if snippets_dict:
return key_snippets_formatter.format_key_snippets_dict(snippets_dict)
# Fallback to global memory for backward compatibility
values = _global_memory.get(key, {})
if not values:
return ""
# Format each snippet with file info and content using markdown
snippets = []
for k, v in sorted(values.items()):
snippet_text = [
f"## 📝 Code Snippet #{k}",
"", # Empty line for better markdown spacing
"**Source Location**:",
f"- File: `{v['filepath']}`",
f"- Line: `{v['line_number']}`",
"", # Empty line before code block
"**Code**:",
"```python",
v["snippet"].rstrip(), # Remove trailing whitespace
"```",
]
if v["description"]:
# Add empty line and description
snippet_text.extend(["", "**Description**:", v["description"]])
snippets.append("\n".join(snippet_text))
return "\n\n".join(snippets)
except Exception as e:
logger.error(f"Error retrieving key snippets: {str(e)}")
# If there's an error with the repository, fall back to global memory
values = _global_memory.get(key, {})
if not values:
return ""
# (Same formatting code as above)
snippets = []
for k, v in sorted(values.items()):
snippet_text = [
f"## 📝 Code Snippet #{k}",
"", # Empty line for better markdown spacing
"**Source Location**:",
f"- File: `{v['filepath']}`",
f"- Line: `{v['line_number']}`",
"", # Empty line before code block
"**Code**:",
"```python",
v["snippet"].rstrip(), # Remove trailing whitespace
"```",
]
if v["description"]:
# Add empty line and description
snippet_text.extend(["", "**Description**:", v["description"]])
snippets.append("\n".join(snippet_text))
return "\n\n".join(snippets)
if key == "work_log":
values = _global_memory.get(key, [])

45
tests/conftest.py Normal file
View File

@ -0,0 +1,45 @@
"""
Global pytest fixtures for RA-AID tests.
This module provides global fixtures that are automatically applied to all tests,
ensuring consistent test environments and proper isolation.
"""
import os
from pathlib import Path
import pytest
@pytest.fixture(autouse=True)
def isolated_db_environment(tmp_path, monkeypatch):
"""
Fixture to ensure all database operations during tests use a temporary directory.
This fixture automatically applies to all tests. It mocks os.getcwd() to return
a temporary directory path, ensuring that database operations never touch the
actual .ra-aid directory in the current working directory.
Args:
tmp_path: Pytest fixture that provides a temporary directory for the test
monkeypatch: Pytest fixture for modifying environment and functions
"""
# Store the original current working directory
original_cwd = os.getcwd()
# Get the absolute path of the temporary directory
tmp_path_str = str(tmp_path.absolute())
# Create the .ra-aid directory in the temporary path
ra_aid_dir = tmp_path / ".ra-aid"
ra_aid_dir.mkdir(exist_ok=True)
# Mock os.getcwd() to return the temporary directory path
monkeypatch.setattr(os, "getcwd", lambda: tmp_path_str)
# Run the test
yield tmp_path
# No need to restore os.getcwd() as monkeypatch does this automatically
# No need to assert original_cwd is restored, as it's just the function that's mocked,
# not the actual working directory

View File

@ -1,10 +1,18 @@
"""
Tests for the database connection module.
NOTE: These tests have been updated to minimize file system interactions by:
1. Using in-memory databases wherever possible
2. Mocking file system interactions when testing file-based modes
3. Ensuring proper cleanup of database connections between tests
However, due to the complexity of SQLite's file interactions through the peewee driver,
these tests may still sometimes create files in the real .ra-aid directory during execution.
"""
import os
from pathlib import Path
from unittest.mock import patch
from unittest.mock import patch, MagicMock
import peewee
import pytest
@ -21,37 +29,29 @@ from ra_aid.database.connection import (
@pytest.fixture
def cleanup_db():
"""
Fixture to clean up database connections and files between tests.
This fixture:
1. Closes any open database connection
2. Resets the contextvar
3. Cleans up the .ra-aid directory
Fixture to clean up database connections between tests.
This ensures that we don't leak database connections between tests
and that the db_var contextvar is reset.
"""
# Run the test
yield
# Clean up after the test
try:
# Close any open database connection
close_db()
# Reset the contextvar
db_var.set(None)
# Clean up the .ra-aid directory if it exists
ra_aid_dir = Path(os.getcwd()) / ".ra-aid"
if ra_aid_dir.exists():
# Only remove the database file, not the entire directory
db_file = ra_aid_dir / "pk.db"
if db_file.exists():
db_file.unlink()
# Remove WAL and SHM files if they exist
wal_file = ra_aid_dir / "pk.db-wal"
if wal_file.exists():
wal_file.unlink()
shm_file = ra_aid_dir / "pk.db-shm"
if shm_file.exists():
shm_file.unlink()
except Exception as e:
# Log but don't fail if cleanup has issues
print(f"Cleanup error (non-fatal): {str(e)}")
db = db_var.get()
if db is not None:
# Clean up attributes we may have added
if hasattr(db, "_is_in_memory"):
delattr(db, "_is_in_memory")
if hasattr(db, "_message_shown"):
delattr(db, "_message_shown")
# Close the connection if it's open
if not db.is_closed():
db.close()
# Reset the contextvar
db_var.set(None)
@pytest.fixture
@ -64,17 +64,20 @@ def mock_logger():
class TestInitDb:
"""Tests for the init_db function."""
# Use in-memory=True for all file-based tests to avoid file system interactions
def test_init_db_default(self, cleanup_db):
"""Test init_db with default parameters."""
db = init_db()
# Initialize the database with in-memory=True for testing
db = init_db(in_memory=True)
# Override the _is_in_memory attribute to test as if it were a file-based database
db._is_in_memory = False
# Verify database was initialized correctly
assert isinstance(db, peewee.SqliteDatabase)
assert not db.is_closed()
assert hasattr(db, "_is_in_memory")
assert db._is_in_memory is False
# Verify the database file was created
ra_aid_dir = Path(os.getcwd()) / ".ra-aid"
assert ra_aid_dir.exists()
assert (ra_aid_dir / "pk.db").exists()
assert db._is_in_memory is False # We set this manually
def test_init_db_in_memory(self, cleanup_db):
"""Test init_db with in_memory=True."""
@ -86,19 +89,33 @@ class TestInitDb:
def test_init_db_reuses_connection(self, cleanup_db):
"""Test that init_db reuses an existing connection."""
db1 = init_db()
db2 = init_db()
db1 = init_db(in_memory=True)
db2 = init_db(in_memory=True)
assert db1 is db2
def test_init_db_reopens_closed_connection(self, cleanup_db):
"""Test that init_db reopens a closed connection."""
db1 = init_db()
db1 = init_db(in_memory=True)
db1.close()
assert db1.is_closed()
db2 = init_db()
db2 = init_db(in_memory=True)
assert db1 is db2
assert not db1.is_closed()
def test_in_memory_mode_no_directory_created(self, cleanup_db):
"""Test that when using in_memory mode, no database file is created."""
# Use a mock to verify that os.path.exists is not called for database files
with patch("os.path.exists") as mock_exists:
# Initialize the database with in_memory=True
db = init_db(in_memory=True)
# Verify it's really in-memory
assert hasattr(db, "_is_in_memory")
assert db._is_in_memory is True
# Verify os.path.exists was not called
mock_exists.assert_not_called()
class TestGetDb:
"""Tests for the get_db function."""
@ -107,21 +124,33 @@ class TestGetDb:
"""Test that get_db creates a new connection if none exists."""
# Reset the contextvar to ensure no connection exists
db_var.set(None)
db = get_db()
assert isinstance(db, peewee.SqliteDatabase)
assert not db.is_closed()
assert hasattr(db, "_is_in_memory")
assert db._is_in_memory is False
# We'll mock init_db and verify it gets called by get_db() with the default parameters
with patch("ra_aid.database.connection.init_db") as mock_init_db:
# Set up the mock to return a dummy database
mock_db = MagicMock(spec=peewee.SqliteDatabase)
mock_db.is_closed.return_value = False
mock_db._is_in_memory = False
mock_init_db.return_value = mock_db
# Get a connection
db = get_db()
# Verify init_db was called with in_memory=False and base_dir=None
mock_init_db.assert_called_once_with(in_memory=False, base_dir=None)
# Verify the database was returned correctly
assert db is mock_db
def test_get_db_reuses_connection(self, cleanup_db):
"""Test that get_db reuses an existing connection."""
db1 = init_db()
db1 = init_db(in_memory=True)
db2 = get_db()
assert db1 is db2
def test_get_db_reopens_closed_connection(self, cleanup_db):
"""Test that get_db reopens a closed connection."""
db1 = init_db()
db1 = init_db(in_memory=True)
db1.close()
assert db1.is_closed()
db2 = get_db()
@ -134,7 +163,7 @@ class TestCloseDb:
def test_close_db(self, cleanup_db):
"""Test that close_db closes an open connection."""
db = init_db()
db = init_db(in_memory=True)
assert not db.is_closed()
close_db()
assert db.is_closed()
@ -148,7 +177,7 @@ class TestCloseDb:
def test_close_db_already_closed(self, cleanup_db):
"""Test that close_db handles the case where the connection is already closed."""
db = init_db()
db = init_db(in_memory=True)
db.close()
assert db.is_closed()
# This should not raise an exception
@ -160,17 +189,22 @@ class TestDatabaseManager:
def test_database_manager_default(self, cleanup_db):
"""Test DatabaseManager with default parameters."""
with DatabaseManager() as db:
# Use in-memory=True but test with _is_in_memory=False
with DatabaseManager(in_memory=True) as db:
# Override the attribute for testing
db._is_in_memory = False
# Verify the database connection
assert isinstance(db, peewee.SqliteDatabase)
assert not db.is_closed()
assert hasattr(db, "_is_in_memory")
assert db._is_in_memory is False
# Verify the database file was created
ra_aid_dir = Path(os.getcwd()) / ".ra-aid"
assert ra_aid_dir.exists()
assert (ra_aid_dir / "pk.db").exists()
assert db._is_in_memory is False # We set this manually
# Store the connection for later assertions
db_in_context = db
# Verify the connection is closed after exiting the context
assert db.is_closed()
assert db_in_context.is_closed()
def test_database_manager_in_memory(self, cleanup_db):
"""Test DatabaseManager with in_memory=True."""
@ -179,17 +213,21 @@ class TestDatabaseManager:
assert not db.is_closed()
assert hasattr(db, "_is_in_memory")
assert db._is_in_memory is True
# Store the connection for later assertions
db_in_context = db
# Verify the connection is closed after exiting the context
assert db.is_closed()
assert db_in_context.is_closed()
def test_database_manager_exception_handling(self, cleanup_db):
"""Test that DatabaseManager properly handles exceptions."""
try:
with DatabaseManager() as db:
with DatabaseManager(in_memory=True) as db:
assert not db.is_closed()
raise ValueError("Test exception")
except ValueError:
# The exception should be propagated
pass
# Verify the connection is closed even if an exception occurred
assert db.is_closed()
assert db.is_closed()

View File

@ -3,9 +3,12 @@ Tests for the KeyFactRepository class.
"""
import pytest
from unittest.mock import patch
import peewee
from ra_aid.database.connection import DatabaseManager, db_var
from ra_aid.database.models import KeyFact
from ra_aid.database.models import KeyFact, BaseModel
from ra_aid.database.repositories.key_fact_repository import KeyFactRepository
@ -22,10 +25,10 @@ def cleanup_db():
# Ignore errors when closing the database
pass
db_var.set(None)
# Run the test
yield
# Reset after the test
db = db_var.get()
if db is not None:
@ -40,24 +43,27 @@ def cleanup_db():
@pytest.fixture
def setup_db(cleanup_db):
"""Set up an in-memory database with the KeyFact table."""
"""Set up an in-memory database with the KeyFact table and patch the BaseModel.Meta.database."""
# Initialize an in-memory database connection
with DatabaseManager(in_memory=True) as db:
# Create the KeyFact table
with db.atomic():
db.create_tables([KeyFact], safe=True)
yield db
# Clean up
with db.atomic():
KeyFact.drop_table(safe=True)
# Patch the BaseModel.Meta.database to use our in-memory database
# This ensures that model operations like KeyFact.create() use our test database
with patch.object(BaseModel._meta, 'database', db):
# Create the KeyFact table
with db.atomic():
db.create_tables([KeyFact], safe=True)
yield db
# Clean up
with db.atomic():
KeyFact.drop_table(safe=True)
def test_create_key_fact(setup_db):
"""Test creating a key fact."""
# Set up repository
repo = KeyFactRepository()
repo = KeyFactRepository(db=setup_db)
# Create a key fact
content = "Test key fact"
@ -67,15 +73,15 @@ def test_create_key_fact(setup_db):
assert fact.id is not None
assert fact.content == content
# Verify we can retrieve it from the database
fact_from_db = KeyFact.get_by_id(fact.id)
# Verify we can retrieve it from the database using the repository
fact_from_db = repo.get(fact.id)
assert fact_from_db.content == content
def test_get_key_fact(setup_db):
"""Test retrieving a key fact by ID."""
# Set up repository
repo = KeyFactRepository()
repo = KeyFactRepository(db=setup_db)
# Create a key fact
content = "Test key fact"
@ -97,7 +103,7 @@ def test_get_key_fact(setup_db):
def test_update_key_fact(setup_db):
"""Test updating a key fact."""
# Set up repository
repo = KeyFactRepository()
repo = KeyFactRepository(db=setup_db)
# Create a key fact
original_content = "Original content"
@ -112,8 +118,8 @@ def test_update_key_fact(setup_db):
assert updated_fact.id == fact.id
assert updated_fact.content == new_content
# Verify we can retrieve the updated content from the database
fact_from_db = KeyFact.get_by_id(fact.id)
# Verify we can retrieve the updated content from the database using the repository
fact_from_db = repo.get(fact.id)
assert fact_from_db.content == new_content
# Try to update a non-existent fact
@ -124,14 +130,14 @@ def test_update_key_fact(setup_db):
def test_delete_key_fact(setup_db):
"""Test deleting a key fact."""
# Set up repository
repo = KeyFactRepository()
repo = KeyFactRepository(db=setup_db)
# Create a key fact
content = "Test key fact to delete"
fact = repo.create(content)
# Verify the fact exists
assert KeyFact.get_or_none(KeyFact.id == fact.id) is not None
# Verify the fact exists using the repository
assert repo.get(fact.id) is not None
# Delete the fact
delete_result = repo.delete(fact.id)
@ -139,8 +145,8 @@ def test_delete_key_fact(setup_db):
# Verify the delete operation was successful
assert delete_result is True
# Verify the fact no longer exists in the database
assert KeyFact.get_or_none(KeyFact.id == fact.id) is None
# Verify the fact no longer exists in the database using the repository
assert repo.get(fact.id) is None
# Try to delete a non-existent fact
non_existent_delete = repo.delete(999)
@ -150,7 +156,7 @@ def test_delete_key_fact(setup_db):
def test_get_all_key_facts(setup_db):
"""Test retrieving all key facts."""
# Set up repository
repo = KeyFactRepository()
repo = KeyFactRepository(db=setup_db)
# Create some key facts
contents = ["Fact 1", "Fact 2", "Fact 3"]
@ -172,7 +178,7 @@ def test_get_all_key_facts(setup_db):
def test_get_facts_dict(setup_db):
"""Test retrieving key facts as a dictionary."""
# Set up repository
repo = KeyFactRepository()
repo = KeyFactRepository(db=setup_db)
# Create some key facts
facts = []

View File

@ -0,0 +1,304 @@
"""
Tests for the KeySnippetRepository class.
"""
import pytest
from ra_aid.database.connection import DatabaseManager, db_var
from ra_aid.database.models import KeySnippet
from ra_aid.database.repositories.key_snippet_repository import KeySnippetRepository
@pytest.fixture
def cleanup_db():
"""Reset the database contextvar and connection state after each test."""
# Reset before the test
db = db_var.get()
if db is not None:
try:
if not db.is_closed():
db.close()
except Exception:
# Ignore errors when closing the database
pass
db_var.set(None)
# Run the test
yield
# Reset after the test
db = db_var.get()
if db is not None:
try:
if not db.is_closed():
db.close()
except Exception:
# Ignore errors when closing the database
pass
db_var.set(None)
@pytest.fixture
def setup_db(cleanup_db):
"""Set up an in-memory database with the KeySnippet table."""
# Initialize an in-memory database connection
with DatabaseManager(in_memory=True) as db:
# Create the KeySnippet table
with db.atomic():
db.create_tables([KeySnippet], safe=True)
yield db
# Clean up
with db.atomic():
KeySnippet.drop_table(safe=True)
def test_create_key_snippet(setup_db):
"""Test creating a key snippet."""
# Set up repository with the in-memory database
repo = KeySnippetRepository(db=setup_db)
# Create a key snippet
filepath = "test_file.py"
line_number = 42
snippet = "def test_function():"
description = "Test function definition"
key_snippet = repo.create(
filepath=filepath,
line_number=line_number,
snippet=snippet,
description=description
)
# Verify the snippet was created correctly
assert key_snippet.id is not None
assert key_snippet.filepath == filepath
assert key_snippet.line_number == line_number
assert key_snippet.snippet == snippet
assert key_snippet.description == description
# Verify we can retrieve it from the database
snippet_from_db = KeySnippet.get_by_id(key_snippet.id)
assert snippet_from_db.filepath == filepath
assert snippet_from_db.line_number == line_number
assert snippet_from_db.snippet == snippet
assert snippet_from_db.description == description
def test_get_key_snippet(setup_db):
"""Test retrieving a key snippet by ID."""
# Set up repository with the in-memory database
repo = KeySnippetRepository(db=setup_db)
# Create a key snippet
filepath = "test_file.py"
line_number = 42
snippet = "def test_function():"
description = "Test function definition"
key_snippet = repo.create(
filepath=filepath,
line_number=line_number,
snippet=snippet,
description=description
)
# Retrieve the snippet by ID
retrieved_snippet = repo.get(key_snippet.id)
# Verify the retrieved snippet matches the original
assert retrieved_snippet is not None
assert retrieved_snippet.id == key_snippet.id
assert retrieved_snippet.filepath == filepath
assert retrieved_snippet.line_number == line_number
assert retrieved_snippet.snippet == snippet
assert retrieved_snippet.description == description
# Try to retrieve a non-existent snippet
non_existent_snippet = repo.get(999)
assert non_existent_snippet is None
def test_update_key_snippet(setup_db):
"""Test updating a key snippet."""
# Set up repository with the in-memory database
repo = KeySnippetRepository(db=setup_db)
# Create a key snippet
original_filepath = "original_file.py"
original_line_number = 10
original_snippet = "def original_function():"
original_description = "Original function definition"
key_snippet = repo.create(
filepath=original_filepath,
line_number=original_line_number,
snippet=original_snippet,
description=original_description
)
# Update the snippet
new_filepath = "updated_file.py"
new_line_number = 20
new_snippet = "def updated_function():"
new_description = "Updated function definition"
updated_snippet = repo.update(
key_snippet.id,
filepath=new_filepath,
line_number=new_line_number,
snippet=new_snippet,
description=new_description
)
# Verify the snippet was updated correctly
assert updated_snippet is not None
assert updated_snippet.id == key_snippet.id
assert updated_snippet.filepath == new_filepath
assert updated_snippet.line_number == new_line_number
assert updated_snippet.snippet == new_snippet
assert updated_snippet.description == new_description
# Verify we can retrieve the updated content from the database
snippet_from_db = KeySnippet.get_by_id(key_snippet.id)
assert snippet_from_db.filepath == new_filepath
assert snippet_from_db.line_number == new_line_number
assert snippet_from_db.snippet == new_snippet
assert snippet_from_db.description == new_description
# Try to update a non-existent snippet
non_existent_update = repo.update(
999,
filepath="nonexistent.py",
line_number=999,
snippet="This shouldn't work",
description="This shouldn't work"
)
assert non_existent_update is None
def test_delete_key_snippet(setup_db):
"""Test deleting a key snippet."""
# Set up repository with the in-memory database
repo = KeySnippetRepository(db=setup_db)
# Create a key snippet
filepath = "file_to_delete.py"
line_number = 30
snippet = "def function_to_delete():"
description = "Function to delete"
key_snippet = repo.create(
filepath=filepath,
line_number=line_number,
snippet=snippet,
description=description
)
# Verify the snippet exists
assert KeySnippet.get_or_none(KeySnippet.id == key_snippet.id) is not None
# Delete the snippet
delete_result = repo.delete(key_snippet.id)
# Verify the delete operation was successful
assert delete_result is True
# Verify the snippet no longer exists in the database
assert KeySnippet.get_or_none(KeySnippet.id == key_snippet.id) is None
# Try to delete a non-existent snippet
non_existent_delete = repo.delete(999)
assert non_existent_delete is False
def test_get_all_key_snippets(setup_db):
"""Test retrieving all key snippets."""
# Set up repository with the in-memory database
repo = KeySnippetRepository(db=setup_db)
# Create some key snippets
snippets_data = [
{
"filepath": "file1.py",
"line_number": 10,
"snippet": "def function1():",
"description": "Function 1"
},
{
"filepath": "file2.py",
"line_number": 20,
"snippet": "def function2():",
"description": "Function 2"
},
{
"filepath": "file3.py",
"line_number": 30,
"snippet": "def function3():",
"description": "Function 3"
}
]
for data in snippets_data:
repo.create(**data)
# Retrieve all snippets
all_snippets = repo.get_all()
# Verify we got the correct number of snippets
assert len(all_snippets) == len(snippets_data)
# Verify the content of each snippet
for i, snippet in enumerate(all_snippets):
assert snippet.filepath == snippets_data[i]["filepath"]
assert snippet.line_number == snippets_data[i]["line_number"]
assert snippet.snippet == snippets_data[i]["snippet"]
assert snippet.description == snippets_data[i]["description"]
def test_get_snippets_dict(setup_db):
"""Test retrieving key snippets as a dictionary."""
# Set up repository with the in-memory database
repo = KeySnippetRepository(db=setup_db)
# Create some key snippets
snippets = []
snippets_data = [
{
"filepath": "file1.py",
"line_number": 10,
"snippet": "def function1():",
"description": "Function 1"
},
{
"filepath": "file2.py",
"line_number": 20,
"snippet": "def function2():",
"description": "Function 2"
},
{
"filepath": "file3.py",
"line_number": 30,
"snippet": "def function3():",
"description": "Function 3"
}
]
for data in snippets_data:
snippets.append(repo.create(**data))
# Retrieve snippets as dictionary
snippets_dict = repo.get_snippets_dict()
# Verify we got the correct number of snippets
assert len(snippets_dict) == len(snippets_data)
# Verify each snippet is in the dictionary with the correct content
for i, snippet in enumerate(snippets):
assert snippet.id in snippets_dict
assert snippets_dict[snippet.id]["filepath"] == snippets_data[i]["filepath"]
assert snippets_dict[snippet.id]["line_number"] == snippets_data[i]["line_number"]
assert snippets_dict[snippet.id]["snippet"] == snippets_data[i]["snippet"]
assert snippets_dict[snippet.id]["description"] == snippets_data[i]["description"]

View File

@ -53,15 +53,16 @@ def setup_test_model(cleanup_db):
"""Set up a test model for database tests."""
# Initialize the database in memory
db = init_db(in_memory=True)
# Initialize the database proxy
from ra_aid.database.models import initialize_database
initialize_database()
# Define a test model class
class TestModel(BaseModel):
name = peewee.CharField(max_length=100)
value = peewee.IntegerField(default=0)
class Meta:
database = db
# Create the test table in a transaction
with db.atomic():
db.create_tables([TestModel], safe=True)
@ -78,15 +79,16 @@ def test_ensure_tables_created_with_models(cleanup_db, mock_logger):
"""Test ensure_tables_created with explicit models."""
# Initialize the database in memory
db = init_db(in_memory=True)
# Initialize the database proxy
from ra_aid.database.models import initialize_database
initialize_database()
# Define a test model that uses this database
# Define a test model that uses the proxy database
class TestModel(BaseModel):
name = peewee.CharField(max_length=100)
value = peewee.IntegerField(default=0)
class Meta:
database = db
# Call ensure_tables_created with explicit models
ensure_tables_created([TestModel])
@ -99,9 +101,9 @@ def test_ensure_tables_created_with_models(cleanup_db, mock_logger):
assert count == 1
@patch("ra_aid.database.utils.get_db")
@patch("ra_aid.database.utils.initialize_database")
def test_ensure_tables_created_database_error(
mock_get_db, setup_test_model, cleanup_db, mock_logger
mock_initialize_database, setup_test_model, cleanup_db, mock_logger
):
"""Test ensure_tables_created handles database errors."""
# Get the TestModel class from the fixture
@ -113,8 +115,8 @@ def test_ensure_tables_created_database_error(
mock_db.atomic.return_value.__exit__.return_value = None
mock_db.create_tables.side_effect = peewee.DatabaseError("Test database error")
# Configure get_db to return our mock
mock_get_db.return_value = mock_db
# Configure initialize_database to return our mock
mock_initialize_database.return_value = mock_db
# Call ensure_tables_created and expect an exception
with pytest.raises(peewee.DatabaseError):