From 038e7b886c7c0db5ece703cab50cbfd2758cc1db Mon Sep 17 00:00:00 2001 From: AI Christianson Date: Sun, 2 Mar 2025 19:06:51 -0500 Subject: [PATCH] key snippets db --- ra_aid/__main__.py | 15 + ra_aid/agent_utils.py | 2 +- ra_aid/database/__init__.py | 3 +- ra_aid/database/connection.py | 46 +- ra_aid/database/models.py | 37 +- .../repositories/key_fact_repository.py | 21 +- .../repositories/key_snippet_repository.py | 214 +++++ .../tests/ra_aid/database/test_connection.py | 835 ++++++------------ ra_aid/database/utils.py | 6 +- ...3_20250302_163752_add_key_snippet_model.py | 57 ++ .../key_snippets_formatter.py | 86 ++ ra_aid/tools/memory.py | 131 ++- tests/conftest.py | 45 + tests/ra_aid/database/test_connection.py | 152 ++-- .../database/test_key_fact_repository.py | 60 +- .../database/test_key_snippet_repository.py | 304 +++++++ tests/ra_aid/database/test_utils.py | 24 +- 17 files changed, 1320 insertions(+), 718 deletions(-) create mode 100644 ra_aid/database/repositories/key_snippet_repository.py create mode 100644 ra_aid/migrations/003_20250302_163752_add_key_snippet_model.py create mode 100644 ra_aid/model_formatters/key_snippets_formatter.py create mode 100644 tests/conftest.py create mode 100644 tests/ra_aid/database/test_key_snippet_repository.py diff --git a/ra_aid/__main__.py b/ra_aid/__main__.py index 36111b6..8d211e1 100644 --- a/ra_aid/__main__.py +++ b/ra_aid/__main__.py @@ -610,6 +610,9 @@ def main(): memory=planning_memory, config=config, ) + + # Run cleanup tasks before exiting database context + run_cleanup() except (KeyboardInterrupt, AgentInterrupt): print() @@ -618,5 +621,17 @@ def main(): sys.exit(0) +def run_cleanup(): + """Run cleanup tasks after main execution.""" + try: + # Import the key facts cleaner agent + from ra_aid.agents.key_facts_gc_agent import run_key_facts_gc_agent + + # Run the key facts garbage collection agent regardless of the number of facts + run_key_facts_gc_agent() + except Exception as e: + logger.error(f"Failed to run cleanup tasks: {str(e)}") + + if __name__ == "__main__": main() \ No newline at end of file diff --git a/ra_aid/agent_utils.py b/ra_aid/agent_utils.py index 650c9f9..35d5f07 100644 --- a/ra_aid/agent_utils.py +++ b/ra_aid/agent_utils.py @@ -419,7 +419,7 @@ def run_research_agent( project_info=formatted_project_info, new_project_hints=NEW_PROJECT_HINTS if project_info.is_new else "", ) - + config = _global_memory.get("config", {}) if not config else config recursion_limit = config.get("recursion_limit", DEFAULT_RECURSION_LIMIT) run_config = { diff --git a/ra_aid/database/__init__.py b/ra_aid/database/__init__.py index ca9eefc..835f043 100644 --- a/ra_aid/database/__init__.py +++ b/ra_aid/database/__init__.py @@ -13,7 +13,7 @@ from ra_aid.database.migrations import ( get_migration_status, init_migrations, ) -from ra_aid.database.models import BaseModel +from ra_aid.database.models import BaseModel, initialize_database from ra_aid.database.utils import ensure_tables_created, get_model_count, truncate_table __all__ = [ @@ -22,6 +22,7 @@ __all__ = [ "close_db", "DatabaseManager", "BaseModel", + "initialize_database", "get_model_count", "truncate_table", "ensure_tables_created", diff --git a/ra_aid/database/connection.py b/ra_aid/database/connection.py index 6df52f8..635442b 100644 --- a/ra_aid/database/connection.py +++ b/ra_aid/database/connection.py @@ -14,6 +14,9 @@ import peewee from ra_aid.logging_config import get_logger +# Import initialize_database after it's defined in models.py +# We need to do the import inside functions to avoid circular imports + # Create contextvar to hold the database connection db_var = contextvars.ContextVar("db", default=None) logger = get_logger(__name__) @@ -34,16 +37,23 @@ class DatabaseManager: # Or with in-memory database: with DatabaseManager(in_memory=True) as db: # Use in-memory database + + # Or with custom base directory: + with DatabaseManager(base_dir="/custom/path") as db: + # Use database in custom directory """ - def __init__(self, in_memory: bool = False): + def __init__(self, in_memory: bool = False, base_dir: Optional[str] = None): """ Initialize the DatabaseManager. Args: in_memory: Whether to use an in-memory database (default: False) + base_dir: Optional base directory to use instead of current working directory. + If None, uses os.getcwd() (default: None) """ self.in_memory = in_memory + self.base_dir = base_dir def __enter__(self) -> peewee.SqliteDatabase: """ @@ -52,7 +62,19 @@ class DatabaseManager: Returns: peewee.SqliteDatabase: The initialized database connection """ - return init_db(in_memory=self.in_memory) + db = init_db(in_memory=self.in_memory, base_dir=self.base_dir) + + # Initialize the database proxy in models.py + try: + # Import here to avoid circular imports + from ra_aid.database.models import initialize_database + initialize_database() + except ImportError as e: + logger.error(f"Failed to import initialize_database: {str(e)}") + except Exception as e: + logger.error(f"Error initializing database proxy: {str(e)}") + + return db def __exit__( self, @@ -74,7 +96,7 @@ class DatabaseManager: return False -def init_db(in_memory: bool = False) -> peewee.SqliteDatabase: +def init_db(in_memory: bool = False, base_dir: Optional[str] = None) -> peewee.SqliteDatabase: """ Initialize the database connection. @@ -84,6 +106,8 @@ def init_db(in_memory: bool = False) -> peewee.SqliteDatabase: Args: in_memory: Whether to use an in-memory database (default: False) + base_dir: Optional base directory to use instead of current working directory. + If None, uses os.getcwd() (default: None) Returns: peewee.SqliteDatabase: The initialized database connection @@ -110,9 +134,9 @@ def init_db(in_memory: bool = False) -> peewee.SqliteDatabase: db_path = ":memory:" logger.debug("Using in-memory SQLite database") else: - # Get current working directory and create .ra-aid directory if it doesn't exist - cwd = os.getcwd() - logger.debug(f"Current working directory: {cwd}") + # Get base directory (use current working directory if not provided) + cwd = base_dir if base_dir is not None else os.getcwd() + logger.debug(f"Base directory for database: {cwd}") # Define the .ra-aid directory path ra_aid_dir_str = os.path.join(cwd, ".ra-aid") @@ -300,13 +324,17 @@ def init_db(in_memory: bool = False) -> peewee.SqliteDatabase: raise -def get_db() -> peewee.SqliteDatabase: +def get_db(base_dir: Optional[str] = None) -> peewee.SqliteDatabase: """ Get the current database connection. If no connection exists, initializes a new one. If connection exists but is closed, reopens it. + Args: + base_dir: Optional base directory to use instead of current working directory. + If None, uses os.getcwd() (default: None) + Returns: peewee.SqliteDatabase: The current database connection """ @@ -315,7 +343,7 @@ def get_db() -> peewee.SqliteDatabase: if db is None: # No database connection exists, initialize one # Use the default in-memory mode (False) - return init_db(in_memory=False) + return init_db(in_memory=False, base_dir=base_dir) # Check if connection is closed and reopen if needed if db.is_closed(): @@ -332,7 +360,7 @@ def get_db() -> peewee.SqliteDatabase: in_memory = hasattr(db, "_is_in_memory") and db._is_in_memory logger.debug(f"Creating new database connection (in_memory={in_memory})") # Create a completely new database object, don't reuse the old one - return init_db(in_memory=in_memory) + return init_db(in_memory=in_memory, base_dir=base_dir) return db diff --git a/ra_aid/database/models.py b/ra_aid/database/models.py index e9875c4..826fc97 100644 --- a/ra_aid/database/models.py +++ b/ra_aid/database/models.py @@ -15,6 +15,21 @@ from ra_aid.logging_config import get_logger T = TypeVar("T", bound="BaseModel") logger = get_logger(__name__) +# Create a database proxy that will be initialized later +database_proxy = peewee.DatabaseProxy() + + +def initialize_database(): + """ + Initialize the database proxy with a real database connection. + + This function should be called before any database operations + to ensure the proxy points to a real database connection. + """ + db = get_db() + database_proxy.initialize(db) + return db + class BaseModel(peewee.Model): """ @@ -28,7 +43,7 @@ class BaseModel(peewee.Model): updated_at = peewee.DateTimeField(default=datetime.datetime.now) class Meta: - database = get_db() + database = database_proxy def save(self, *args: Any, **kwargs: Any) -> int: """ @@ -75,4 +90,22 @@ class KeyFact(BaseModel): # created_at and updated_at are inherited from BaseModel class Meta: - table_name = "key_fact" \ No newline at end of file + table_name = "key_fact" + + +class KeySnippet(BaseModel): + """ + Model representing a key code snippet stored in the database. + + Key snippets are important code fragments from the project that need to be + referenced later. Each snippet includes its file location, line number, + the code content itself, and an optional description of its significance. + """ + filepath = peewee.TextField() + line_number = peewee.IntegerField() + snippet = peewee.TextField() + description = peewee.TextField(null=True) + # created_at and updated_at are inherited from BaseModel + + class Meta: + table_name = "key_snippet" \ No newline at end of file diff --git a/ra_aid/database/repositories/key_fact_repository.py b/ra_aid/database/repositories/key_fact_repository.py index 1d0e38a..1dfa824 100644 --- a/ra_aid/database/repositories/key_fact_repository.py +++ b/ra_aid/database/repositories/key_fact_repository.py @@ -10,7 +10,7 @@ from typing import Dict, List, Optional import peewee from ra_aid.database.connection import get_db -from ra_aid.database.models import KeyFact +from ra_aid.database.models import KeyFact, initialize_database from ra_aid.logging_config import get_logger logger = get_logger(__name__) @@ -29,6 +29,15 @@ class KeyFactRepository: all_facts = repo.get_all() """ + def __init__(self, db=None): + """ + Initialize the repository with an optional database connection. + + Args: + db: Optional database connection to use. If None, will use initialize_database() + """ + self.db = db + def create(self, content: str) -> KeyFact: """ Create a new key fact in the database. @@ -43,7 +52,7 @@ class KeyFactRepository: peewee.DatabaseError: If there's an error creating the fact """ try: - db = get_db() + db = self.db if self.db is not None else initialize_database() fact = KeyFact.create(content=content) logger.debug(f"Created key fact ID {fact.id}: {content}") return fact @@ -65,7 +74,7 @@ class KeyFactRepository: peewee.DatabaseError: If there's an error accessing the database """ try: - db = get_db() + db = self.db if self.db is not None else initialize_database() return KeyFact.get_or_none(KeyFact.id == fact_id) except peewee.DatabaseError as e: logger.error(f"Failed to fetch key fact {fact_id}: {str(e)}") @@ -86,7 +95,7 @@ class KeyFactRepository: peewee.DatabaseError: If there's an error updating the fact """ try: - db = get_db() + db = self.db if self.db is not None else initialize_database() # First check if the fact exists fact = self.get(fact_id) if not fact: @@ -116,7 +125,7 @@ class KeyFactRepository: peewee.DatabaseError: If there's an error deleting the fact """ try: - db = get_db() + db = self.db if self.db is not None else initialize_database() # First check if the fact exists fact = self.get(fact_id) if not fact: @@ -142,7 +151,7 @@ class KeyFactRepository: peewee.DatabaseError: If there's an error accessing the database """ try: - db = get_db() + db = self.db if self.db is not None else initialize_database() return list(KeyFact.select().order_by(KeyFact.id)) except peewee.DatabaseError as e: logger.error(f"Failed to fetch all key facts: {str(e)}") diff --git a/ra_aid/database/repositories/key_snippet_repository.py b/ra_aid/database/repositories/key_snippet_repository.py new file mode 100644 index 0000000..f4d6f2c --- /dev/null +++ b/ra_aid/database/repositories/key_snippet_repository.py @@ -0,0 +1,214 @@ +""" +Key snippet repository implementation for database access. + +This module provides a repository implementation for the KeySnippet model, +following the repository pattern for data access abstraction. +""" + +from typing import Dict, List, Optional, Any + +import peewee + +from ra_aid.database.connection import get_db +from ra_aid.database.models import KeySnippet, initialize_database +from ra_aid.logging_config import get_logger + +logger = get_logger(__name__) + + +class KeySnippetRepository: + """ + Repository for managing KeySnippet database operations. + + This class provides methods for performing CRUD operations on the KeySnippet model, + abstracting the database access details from the business logic. + + Example: + repo = KeySnippetRepository() + snippet = repo.create( + filepath="main.py", + line_number=42, + snippet="def hello_world():", + description="Main function definition" + ) + all_snippets = repo.get_all() + """ + + def __init__(self, db=None): + """ + Initialize the repository with an optional database connection. + + Args: + db: Optional database connection to use. If None, will use initialize_database() + """ + self.db = db + + def create( + self, filepath: str, line_number: int, snippet: str, description: Optional[str] = None + ) -> KeySnippet: + """ + Create a new key snippet in the database. + + Args: + filepath: Path to the source file + line_number: Line number where the snippet starts + snippet: The source code snippet text + description: Optional description of the significance + + Returns: + KeySnippet: The newly created key snippet instance + + Raises: + peewee.DatabaseError: If there's an error creating the snippet + """ + try: + db = self.db if self.db is not None else initialize_database() + key_snippet = KeySnippet.create( + filepath=filepath, + line_number=line_number, + snippet=snippet, + description=description + ) + logger.debug(f"Created key snippet ID {key_snippet.id}: {filepath}:{line_number}") + return key_snippet + except peewee.DatabaseError as e: + logger.error(f"Failed to create key snippet: {str(e)}") + raise + + def get(self, snippet_id: int) -> Optional[KeySnippet]: + """ + Retrieve a key snippet by its ID. + + Args: + snippet_id: The ID of the key snippet to retrieve + + Returns: + Optional[KeySnippet]: The key snippet instance if found, None otherwise + + Raises: + peewee.DatabaseError: If there's an error accessing the database + """ + try: + db = self.db if self.db is not None else initialize_database() + return KeySnippet.get_or_none(KeySnippet.id == snippet_id) + except peewee.DatabaseError as e: + logger.error(f"Failed to fetch key snippet {snippet_id}: {str(e)}") + raise + + def update( + self, + snippet_id: int, + filepath: str, + line_number: int, + snippet: str, + description: Optional[str] = None + ) -> Optional[KeySnippet]: + """ + Update an existing key snippet. + + Args: + snippet_id: The ID of the key snippet to update + filepath: Path to the source file + line_number: Line number where the snippet starts + snippet: The source code snippet text + description: Optional description of the significance + + Returns: + Optional[KeySnippet]: The updated key snippet if found, None otherwise + + Raises: + peewee.DatabaseError: If there's an error updating the snippet + """ + try: + db = self.db if self.db is not None else initialize_database() + # First check if the snippet exists + key_snippet = self.get(snippet_id) + if not key_snippet: + logger.warning(f"Attempted to update non-existent key snippet {snippet_id}") + return None + + # Update the snippet + key_snippet.filepath = filepath + key_snippet.line_number = line_number + key_snippet.snippet = snippet + key_snippet.description = description + key_snippet.save() + logger.debug(f"Updated key snippet ID {snippet_id}: {filepath}:{line_number}") + return key_snippet + except peewee.DatabaseError as e: + logger.error(f"Failed to update key snippet {snippet_id}: {str(e)}") + raise + + def delete(self, snippet_id: int) -> bool: + """ + Delete a key snippet by its ID. + + Args: + snippet_id: The ID of the key snippet to delete + + Returns: + bool: True if the snippet was deleted, False if it wasn't found + + Raises: + peewee.DatabaseError: If there's an error deleting the snippet + """ + try: + db = self.db if self.db is not None else initialize_database() + # First check if the snippet exists + key_snippet = self.get(snippet_id) + if not key_snippet: + logger.warning(f"Attempted to delete non-existent key snippet {snippet_id}") + return False + + # Delete the snippet + key_snippet.delete_instance() + logger.debug(f"Deleted key snippet ID {snippet_id}") + return True + except peewee.DatabaseError as e: + logger.error(f"Failed to delete key snippet {snippet_id}: {str(e)}") + raise + + def get_all(self) -> List[KeySnippet]: + """ + Retrieve all key snippets from the database. + + Returns: + List[KeySnippet]: List of all key snippet instances + + Raises: + peewee.DatabaseError: If there's an error accessing the database + """ + try: + db = self.db if self.db is not None else initialize_database() + return list(KeySnippet.select().order_by(KeySnippet.id)) + except peewee.DatabaseError as e: + logger.error(f"Failed to fetch all key snippets: {str(e)}") + raise + + def get_snippets_dict(self) -> Dict[int, Dict[str, Any]]: + """ + Retrieve all key snippets as a dictionary mapping IDs to snippet information. + + This method is useful for compatibility with the existing memory format. + + Returns: + Dict[int, Dict[str, Any]]: Dictionary with snippet IDs as keys and + snippet information as values + + Raises: + peewee.DatabaseError: If there's an error accessing the database + """ + try: + snippets = self.get_all() + return { + snippet.id: { + "filepath": snippet.filepath, + "line_number": snippet.line_number, + "snippet": snippet.snippet, + "description": snippet.description + } + for snippet in snippets + } + except peewee.DatabaseError as e: + logger.error(f"Failed to fetch key snippets as dictionary: {str(e)}") + raise \ No newline at end of file diff --git a/ra_aid/database/tests/ra_aid/database/test_connection.py b/ra_aid/database/tests/ra_aid/database/test_connection.py index b43ae52..85bad08 100644 --- a/ra_aid/database/tests/ra_aid/database/test_connection.py +++ b/ra_aid/database/tests/ra_aid/database/test_connection.py @@ -1,5 +1,8 @@ """ Tests for the database connection module. + +This file tests the database connection functionality using pytest's fixtures +for proper test isolation. """ import os @@ -21,642 +24,346 @@ from ra_aid.database.connection import ( @pytest.fixture def cleanup_db(): """ - Fixture to clean up database connections after tests. + Fixture to clean up database connections between tests. + + This ensures that we don't leak database connections between tests + and that the db_var contextvar is reset. """ # Run the test yield - + # Clean up after the test db = db_var.get() if db is not None: + # Clean up attributes we may have added if hasattr(db, "_is_in_memory"): delattr(db, "_is_in_memory") if hasattr(db, "_message_shown"): delattr(db, "_message_shown") + + # Close the connection if it's open if not db.is_closed(): db.close() - + # Reset the contextvar db_var.set(None) @pytest.fixture -def setup_in_memory_db(): +def db_path_mock(tmp_path, monkeypatch): """ - Fixture to set up an in-memory database for testing. + Fixture to mock os.getcwd() to return a temporary directory path. + + This ensures that all database operations use the temporary directory + and never touch the actual current working directory. """ - # Initialize in-memory database - db = init_db(in_memory=True) - - # Run the test - yield db - - # Clean up - if not db.is_closed(): - db.close() - db_var.set(None) - - -def test_init_db_creates_directory(cleanup_db, tmp_path): - """ - Test that init_db creates the .ra-aid directory if it doesn't exist. - """ - # Get and print the original working directory original_cwd = os.getcwd() - print(f"Original working directory: {original_cwd}") - - # Convert tmp_path to string for consistent handling tmp_path_str = str(tmp_path.absolute()) - print(f"Temporary directory path: {tmp_path_str}") - - # Change to the temporary directory - os.chdir(tmp_path_str) - current_cwd = os.getcwd() - print(f"Changed working directory to: {current_cwd}") - assert ( - current_cwd == tmp_path_str - ), f"Failed to change directory: {current_cwd} != {tmp_path_str}" - - # Create the .ra-aid directory manually to ensure it exists - ra_aid_path_str = os.path.join(current_cwd, ".ra-aid") - print(f"Creating .ra-aid directory at: {ra_aid_path_str}") - os.makedirs(ra_aid_path_str, exist_ok=True) - - # Verify the directory was created - assert os.path.exists( - ra_aid_path_str - ), f".ra-aid directory not found at {ra_aid_path_str}" - assert os.path.isdir( - ra_aid_path_str - ), f"{ra_aid_path_str} exists but is not a directory" - - # Create a test file to verify write permissions - test_file_path = os.path.join(ra_aid_path_str, "test_write.txt") - print(f"Creating test file to verify write permissions: {test_file_path}") - with open(test_file_path, "w") as f: - f.write("Test write permissions") - - # Verify the test file was created - assert os.path.exists(test_file_path), f"Test file not created at {test_file_path}" - - # Create an empty database file to ensure it exists before init_db - db_file_str = os.path.join(ra_aid_path_str, "pk.db") - print(f"Creating empty database file at: {db_file_str}") - with open(db_file_str, "w") as f: - f.write("") # Create empty file - - # Verify the database file was created - assert os.path.exists( - db_file_str - ), f"Empty database file not created at {db_file_str}" - print(f"Empty database file size: {os.path.getsize(db_file_str)} bytes") - - # Get directory permissions for debugging - dir_perms = oct(os.stat(ra_aid_path_str).st_mode)[-3:] - print(f"Directory permissions: {dir_perms}") - - # Initialize the database - print("Calling init_db()") - db = init_db() - print("init_db() returned successfully") - - # List contents of the current directory for debugging - print(f"Contents of current directory: {os.listdir(current_cwd)}") - - # List contents of the .ra-aid directory for debugging - print(f"Contents of .ra-aid directory: {os.listdir(ra_aid_path_str)}") - - # Check that the database file exists using os.path - assert os.path.exists(db_file_str), f"Database file not found at {db_file_str}" - assert os.path.isfile(db_file_str), f"{db_file_str} exists but is not a file" - print(f"Final database file size: {os.path.getsize(db_file_str)} bytes") - - -def test_init_db_creates_database_file(cleanup_db, tmp_path): - """ - Test that init_db creates the database file. - """ - # Change to the temporary directory - os.chdir(tmp_path) - - # Initialize the database - init_db() - - # Check that the database file was created - assert (tmp_path / ".ra-aid" / "pk.db").exists() - assert (tmp_path / ".ra-aid" / "pk.db").is_file() - - -def test_init_db_returns_database_connection(cleanup_db): - """ - Test that init_db returns a database connection. - """ - # Initialize the database - db = init_db() - - # Check that the database connection is returned - assert isinstance(db, peewee.SqliteDatabase) - assert not db.is_closed() - - -def test_init_db_with_in_memory_mode(cleanup_db): - """ - Test that init_db with in_memory=True creates an in-memory database. - """ - # Initialize the database in in-memory mode - db = init_db(in_memory=True) - - # Check that the database connection is returned - assert isinstance(db, peewee.SqliteDatabase) - assert not db.is_closed() - assert hasattr(db, "_is_in_memory") - assert db._is_in_memory is True - - -def test_in_memory_mode_no_directory_created(cleanup_db, tmp_path): - """ - Test that when using in-memory mode, no directory is created. - """ - # Change to the temporary directory - os.chdir(tmp_path) - - # Initialize the database in in-memory mode - init_db(in_memory=True) - - # Check that the .ra-aid directory was not created - # (Note: it might be created by other tests, so we can't assert it doesn't exist) - # Instead, check that the database file was not created - assert not (tmp_path / ".ra-aid" / "pk.db").exists() - - -def test_init_db_returns_existing_connection(cleanup_db): - """ - Test that init_db returns the existing connection if one exists. - """ - # Initialize the database - db1 = init_db() - - # Initialize the database again - db2 = init_db() - - # Check that the same connection is returned - assert db1 is db2 - - -def test_init_db_reopens_closed_connection(cleanup_db): - """ - Test that init_db reopens a closed connection. - """ - # Initialize the database - db1 = init_db() - - # Close the connection - db1.close() - - # Initialize the database again - db2 = init_db() - - # Check that the same connection is returned and it's open - assert db1 is db2 - assert not db1.is_closed() - - -def test_get_db_initializes_connection(cleanup_db): - """ - Test that get_db initializes a connection if none exists. - """ - # Get the database connection - db = get_db() - - # Check that a connection was initialized - assert isinstance(db, peewee.SqliteDatabase) - assert not db.is_closed() - - -def test_get_db_returns_existing_connection(cleanup_db): - """ - Test that get_db returns the existing connection if one exists. - """ - # Initialize the database - db1 = init_db() - - # Get the database connection - db2 = get_db() - - # Check that the same connection is returned - assert db1 is db2 - - -def test_get_db_reopens_closed_connection(cleanup_db): - """ - Test that get_db reopens a closed connection. - """ - # Initialize the database - db = init_db() - - # Close the connection - db.close() - - # Get the database connection - db2 = get_db() - - # Check that the same connection is returned and it's open - assert db is db2 - assert not db.is_closed() - - -def test_get_db_handles_reopen_error(cleanup_db, monkeypatch): - """ - Test that get_db handles errors when reopening a connection. - """ - # Initialize the database - db = init_db() - - # Close the connection - db.close() - - # Create a patched version of the connect method that raises an error - original_connect = peewee.SqliteDatabase.connect - - def mock_connect(self, *args, **kwargs): - if self is db: # Only raise for the specific db instance - raise peewee.OperationalError("Test error") - return original_connect(self, *args, **kwargs) - - # Apply the patch - monkeypatch.setattr(peewee.SqliteDatabase, "connect", mock_connect) - - # Get the database connection - db2 = get_db() - - # Check that a new connection was initialized - assert db is not db2 - assert not db2.is_closed() - - -def test_close_db_closes_connection(cleanup_db): - """ - Test that close_db closes the connection. - """ - # Initialize the database - db = init_db() - - # Close the connection - close_db() - - # Check that the connection is closed - assert db.is_closed() - - -def test_close_db_handles_no_connection(): - """ - Test that close_db handles the case where no connection exists. - """ - # Reset the contextvar - db_var.set(None) - - # Close the connection (should not raise an error) - close_db() - - -def test_close_db_handles_already_closed_connection(cleanup_db): - """ - Test that close_db handles the case where the connection is already closed. - """ - # Initialize the database - db = init_db() - - # Close the connection - db.close() - - # Close the connection again (should not raise an error) - close_db() - - -@patch("ra_aid.database.connection.peewee.SqliteDatabase.close") -def test_close_db_handles_error(mock_close, cleanup_db): - """ - Test that close_db handles errors when closing the connection. - """ - # Initialize the database - init_db() - - # Make close raise an error - mock_close.side_effect = peewee.DatabaseError("Test error") - - # Close the connection (should not raise an error) - close_db() - - -def test_database_manager_context_manager(cleanup_db): - """ - Test that DatabaseManager works as a context manager. - """ - # Use the context manager - with DatabaseManager() as db: - # Check that a connection was initialized - assert isinstance(db, peewee.SqliteDatabase) - assert not db.is_closed() - - # Store the connection for later - db_in_context = db - - # Check that the connection is closed after exiting the context - assert db_in_context.is_closed() - - -def test_database_manager_with_in_memory_mode(cleanup_db): - """ - Test that DatabaseManager with in_memory=True creates an in-memory database. - """ - # Use the context manager with in_memory=True - with DatabaseManager(in_memory=True) as db: - # Check that a connection was initialized - assert isinstance(db, peewee.SqliteDatabase) - assert not db.is_closed() - assert hasattr(db, "_is_in_memory") - assert db._is_in_memory is True - - -def test_init_db_shows_message_only_once(cleanup_db, caplog): - """ - Test that init_db only shows the initialization message once. - """ - # Initialize the database - init_db(in_memory=True) - - # Clear the log - caplog.clear() - - # Initialize the database again - init_db(in_memory=True) - - # Check that no message was logged - assert "database connection initialized" not in caplog.text.lower() - - -def test_init_db_sets_is_in_memory_attribute(cleanup_db): - """ - Test that init_db sets the _is_in_memory attribute. - """ - # Initialize the database with in_memory=False - db = init_db(in_memory=False) - - # Check that the _is_in_memory attribute is set to False - assert hasattr(db, "_is_in_memory") - assert db._is_in_memory is False - - # Reset the contextvar - db_var.set(None) - - # Initialize the database with in_memory=True - db = init_db(in_memory=True) - - # Check that the _is_in_memory attribute is set to True - assert hasattr(db, "_is_in_memory") - assert db._is_in_memory is True - - -""" -Tests for the database connection module. -""" - - -import pytest - - - -@pytest.fixture -def cleanup_db(): - """ - Fixture to clean up database connections and files between tests. - - This fixture: - 1. Closes any open database connection - 2. Resets the contextvar - 3. Cleans up the .ra-aid directory - """ - # Store the current working directory - original_cwd = os.getcwd() - - # Run the test - yield - - # Clean up after the test - try: - # Close any open database connection - close_db() - - # Reset the contextvar - db_var.set(None) - - # Clean up the .ra-aid directory if it exists - ra_aid_dir = Path(os.getcwd()) / ".ra-aid" - ra_aid_dir_str = str(ra_aid_dir.absolute()) - - # Check using both methods - path_exists = ra_aid_dir.exists() - os_exists = os.path.exists(ra_aid_dir_str) - - print(f"Cleanup check: Path.exists={path_exists}, os.path.exists={os_exists}") - - if os_exists: - # Only remove the database file, not the entire directory - db_file = os.path.join(ra_aid_dir_str, "pk.db") - if os.path.exists(db_file): - os.unlink(db_file) - - # Remove WAL and SHM files if they exist - wal_file = os.path.join(ra_aid_dir_str, "pk.db-wal") - if os.path.exists(wal_file): - os.unlink(wal_file) - - shm_file = os.path.join(ra_aid_dir_str, "pk.db-shm") - if os.path.exists(shm_file): - os.unlink(shm_file) - - # List remaining contents for debugging - if os.path.exists(ra_aid_dir_str): - print(f"Directory contents after cleanup: {os.listdir(ra_aid_dir_str)}") - except Exception as e: - # Log but don't fail if cleanup has issues - print(f"Cleanup error (non-fatal): {str(e)}") - - # Make sure we're back in the original directory + + # Create the .ra-aid directory in the temporary path + ra_aid_dir = tmp_path / ".ra-aid" + ra_aid_dir.mkdir(exist_ok=True) + + # Mock os.getcwd() to return the temporary directory + monkeypatch.setattr(os, "getcwd", lambda: tmp_path_str) + + yield tmp_path + + # Ensure we're back to the original directory after the test os.chdir(original_cwd) class TestInitDb: """Tests for the init_db function.""" - - def test_init_db_default(self, cleanup_db): - """Test init_db with default parameters.""" - # Get the absolute path of the current working directory - cwd = os.getcwd() - print(f"Current working directory: {cwd}") - - # Initialize the database - db = init_db() - - assert isinstance(db, peewee.SqliteDatabase) - assert not db.is_closed() - assert hasattr(db, "_is_in_memory") - assert db._is_in_memory is False - - # Verify the database file was created using both Path and os.path methods - ra_aid_dir = Path(cwd) / ".ra-aid" - ra_aid_dir_str = str(ra_aid_dir.absolute()) - - # Check directory existence using both methods - path_exists = ra_aid_dir.exists() - os_exists = os.path.exists(ra_aid_dir_str) - print(f"Directory check: Path.exists={path_exists}, os.path.exists={os_exists}") - - # List the contents of the current directory - print(f"Contents of {cwd}: {os.listdir(cwd)}") - - # If the directory exists, list its contents - if os_exists: - print(f"Contents of {ra_aid_dir_str}: {os.listdir(ra_aid_dir_str)}") - - # Use os.path for assertions to be more reliable - assert os.path.exists( - ra_aid_dir_str - ), f"Directory {ra_aid_dir_str} does not exist" - assert os.path.isdir(ra_aid_dir_str), f"{ra_aid_dir_str} is not a directory" - - db_file = os.path.join(ra_aid_dir_str, "pk.db") - assert os.path.exists(db_file), f"Database file {db_file} does not exist" - assert os.path.isfile(db_file), f"{db_file} is not a file" - + def test_init_db_in_memory(self, cleanup_db): """Test init_db with in_memory=True.""" + # Reset the contextvar to ensure a fresh start + db_var.set(None) db = init_db(in_memory=True) - + assert isinstance(db, peewee.SqliteDatabase) assert not db.is_closed() assert hasattr(db, "_is_in_memory") assert db._is_in_memory is True - - def test_init_db_reuses_connection(self, cleanup_db): - """Test that init_db reuses an existing connection.""" - db1 = init_db() - db2 = init_db() - - assert db1 is db2 - - def test_init_db_reopens_closed_connection(self, cleanup_db): - """Test that init_db reopens a closed connection.""" - db1 = init_db() - db1.close() - assert db1.is_closed() - - db2 = init_db() - assert db1 is db2 - assert not db1.is_closed() - - -class TestGetDb: - """Tests for the get_db function.""" - - def test_get_db_creates_connection(self, cleanup_db): - """Test that get_db creates a new connection if none exists.""" - # Reset the contextvar to ensure no connection exists - db_var.set(None) - - db = get_db() - + + def test_init_db_creates_directory(self, cleanup_db, db_path_mock): + """Test that init_db creates the .ra-aid directory if it doesn't exist.""" + # Remove the .ra-aid directory to test creation + ra_aid_dir = db_path_mock / ".ra-aid" + if ra_aid_dir.exists(): + for item in ra_aid_dir.iterdir(): + if item.is_file(): + item.unlink() + ra_aid_dir.rmdir() + + # Initialize the database + db = init_db() + + # Check that the directory was created + assert ra_aid_dir.exists() + assert ra_aid_dir.is_dir() assert isinstance(db, peewee.SqliteDatabase) assert not db.is_closed() assert hasattr(db, "_is_in_memory") assert db._is_in_memory is False - - def test_get_db_reuses_connection(self, cleanup_db): - """Test that get_db reuses an existing connection.""" - db1 = init_db() - db2 = get_db() - + + def test_init_db_creates_database_file(self, cleanup_db, db_path_mock): + """Test that init_db creates the database file.""" + # Initialize the database + init_db() + + # Check that the database file was created + assert (db_path_mock / ".ra-aid" / "pk.db").exists() + assert (db_path_mock / ".ra-aid" / "pk.db").is_file() + + def test_init_db_reuses_connection(self, cleanup_db): + """Test that init_db reuses an existing connection.""" + # Reset the contextvar to ensure a fresh start + db_var.set(None) + + # Use in_memory=True for this test to avoid touching the filesystem + db1 = init_db(in_memory=True) + db2 = init_db(in_memory=True) + assert db1 is db2 - - def test_get_db_reopens_closed_connection(self, cleanup_db): - """Test that get_db reopens a closed connection.""" - db1 = init_db() + + def test_init_db_reopens_closed_connection(self, cleanup_db): + """Test that init_db reopens a closed connection.""" + # Reset the contextvar to ensure a fresh start + db_var.set(None) + + # Use in_memory=True for this test to avoid touching the filesystem + db1 = init_db(in_memory=True) db1.close() assert db1.is_closed() + + db2 = init_db(in_memory=True) + assert db1 is db2 + assert not db1.is_closed() + + def test_in_memory_mode_no_directory_created(self, cleanup_db, db_path_mock): + """Test that when using in_memory mode, no database file is created.""" + # Initialize the database in in-memory mode + init_db(in_memory=True) + + # Check that the database file was not created + assert not (db_path_mock / ".ra-aid" / "pk.db").exists() + + def test_init_db_sets_is_in_memory_attribute(self, cleanup_db): + """Test that init_db sets the _is_in_memory attribute.""" + # Test with in_memory=True + db = init_db(in_memory=True) + assert hasattr(db, "_is_in_memory") + assert db._is_in_memory is True + + # Reset the contextvar + db_var.set(None) + + # Test with in_memory=False, but use a mocked directory + with patch("os.getcwd") as mock_getcwd: + temp_dir = Path("/tmp/testdb") + mock_getcwd.return_value = str(temp_dir) + + # Mock os.path.exists and os.makedirs to avoid filesystem operations + with patch("os.path.exists", return_value=True): + with patch("os.makedirs"): + with patch("os.path.isdir", return_value=True): + with patch.object(peewee.SqliteDatabase, "connect"): + with patch.object(peewee.SqliteDatabase, "execute_sql"): + db = init_db(in_memory=False) + assert hasattr(db, "_is_in_memory") + assert db._is_in_memory is False + +class TestGetDb: + """Tests for the get_db function.""" + + def test_get_db_initializes_connection(self, cleanup_db): + """Test that get_db initializes a connection if none exists.""" + # Reset the contextvar to ensure no connection exists + db_var.set(None) + + # Use a patch to avoid touching the filesystem + with patch("ra_aid.database.connection.init_db") as mock_init_db: + mock_db = peewee.SqliteDatabase(":memory:") + mock_db._is_in_memory = False + mock_init_db.return_value = mock_db + + db = get_db() + + mock_init_db.assert_called_once_with(in_memory=False, base_dir=None) + assert db is mock_db + + def test_get_db_returns_existing_connection(self, cleanup_db): + """Test that get_db returns the existing connection if one exists.""" + # Reset the contextvar to ensure a fresh start + db_var.set(None) + + # Use in_memory=True for this test to avoid touching the filesystem + db1 = init_db(in_memory=True) + db2 = get_db() + + assert db1 is db2 + + def test_get_db_reopens_closed_connection(self, cleanup_db): + """Test that get_db reopens a closed connection.""" + # Reset the contextvar to ensure a fresh start + db_var.set(None) + + # Use in_memory=True for this test to avoid touching the filesystem + db1 = init_db(in_memory=True) + db1.close() + assert db1.is_closed() + db2 = get_db() assert db1 is db2 assert not db1.is_closed() + + def test_get_db_handles_reopen_error(self, cleanup_db, monkeypatch): + """Test that get_db handles errors when reopening a connection.""" + # Reset the contextvar to ensure a fresh start + db_var.set(None) + + # Use in_memory=True for this test to avoid touching the filesystem + db = init_db(in_memory=True) + + # Close the connection + db.close() + + # Create a patched version of the connect method that raises an error + original_connect = peewee.SqliteDatabase.connect + + def mock_connect(self, *args, **kwargs): + if self is db: # Only raise for the specific db instance + raise peewee.OperationalError("Test error") + return original_connect(self, *args, **kwargs) + + # Apply the patch + monkeypatch.setattr(peewee.SqliteDatabase, "connect", mock_connect) + + # Get the database connection - this should create a new one + db2 = get_db() + + # Check that a new connection was initialized + assert db is not db2 + assert not db2.is_closed() + assert hasattr(db2, "_is_in_memory") + assert db2._is_in_memory is True # Should preserve the in_memory setting class TestCloseDb: """Tests for the close_db function.""" - - def test_close_db(self, cleanup_db): - """Test that close_db closes an open connection.""" - db = init_db() - assert not db.is_closed() - + + def test_close_db_closes_connection(self, cleanup_db): + """Test that close_db closes the connection.""" + # Use in_memory=True for this test to avoid touching the filesystem + db = init_db(in_memory=True) + + # Close the connection close_db() + + # Check that the connection is closed assert db.is_closed() - - def test_close_db_no_connection(self, cleanup_db): + + def test_close_db_handles_no_connection(self): """Test that close_db handles the case where no connection exists.""" - # Reset the contextvar to ensure no connection exists + # Reset the contextvar db_var.set(None) - - # This should not raise an exception + + # Close the connection (should not raise an error) close_db() - - def test_close_db_already_closed(self, cleanup_db): + + def test_close_db_handles_already_closed_connection(self, cleanup_db): """Test that close_db handles the case where the connection is already closed.""" - db = init_db() + # Use in_memory=True for this test to avoid touching the filesystem + db = init_db(in_memory=True) + + # Close the connection db.close() - assert db.is_closed() - - # This should not raise an exception + + # Close the connection again (should not raise an error) + close_db() + + @patch("ra_aid.database.connection.peewee.SqliteDatabase.close") + def test_close_db_handles_error(self, mock_close, cleanup_db): + """Test that close_db handles errors when closing the connection.""" + # Use in_memory=True for this test to avoid touching the filesystem + init_db(in_memory=True) + + # Make close raise an error + mock_close.side_effect = peewee.DatabaseError("Test error") + + # Close the connection (should not raise an error) close_db() class TestDatabaseManager: """Tests for the DatabaseManager class.""" - - def test_database_manager_default(self, cleanup_db): - """Test DatabaseManager with default parameters.""" - with DatabaseManager() as db: - assert isinstance(db, peewee.SqliteDatabase) - assert not db.is_closed() - assert hasattr(db, "_is_in_memory") - assert db._is_in_memory is False - - # Verify the database file was created - ra_aid_dir = Path(os.getcwd()) / ".ra-aid" - assert ra_aid_dir.exists() - assert (ra_aid_dir / "pk.db").exists() - - # Verify the connection is closed after exiting the context - assert db.is_closed() - - def test_database_manager_in_memory(self, cleanup_db): - """Test DatabaseManager with in_memory=True.""" + + def test_database_manager_context_manager_in_memory(self, cleanup_db): + """Test that DatabaseManager works as a context manager with in_memory=True.""" + # Use in_memory=True for this test to avoid touching the filesystem with DatabaseManager(in_memory=True) as db: + # Check that a connection was initialized assert isinstance(db, peewee.SqliteDatabase) assert not db.is_closed() assert hasattr(db, "_is_in_memory") assert db._is_in_memory is True - - # Verify the connection is closed after exiting the context - assert db.is_closed() - + + # Store the connection for later + db_in_context = db + + # Check that the connection is closed after exiting the context + assert db_in_context.is_closed() + + def test_database_manager_context_manager_physical_file(self, cleanup_db, db_path_mock): + """Test that DatabaseManager works as a context manager with a physical file.""" + with DatabaseManager(in_memory=False) as db: + # Check that a connection was initialized + assert isinstance(db, peewee.SqliteDatabase) + assert not db.is_closed() + assert hasattr(db, "_is_in_memory") + assert db._is_in_memory is False + + # Check that the database file was created + assert (db_path_mock / ".ra-aid" / "pk.db").exists() + assert (db_path_mock / ".ra-aid" / "pk.db").is_file() + + # Store the connection for later + db_in_context = db + + # Check that the connection is closed after exiting the context + assert db_in_context.is_closed() + def test_database_manager_exception_handling(self, cleanup_db): """Test that DatabaseManager properly handles exceptions.""" + # Use in_memory=True for this test to avoid touching the filesystem try: - with DatabaseManager() as db: + with DatabaseManager(in_memory=True) as db: assert not db.is_closed() raise ValueError("Test exception") except ValueError: # The exception should be propagated pass - + # Verify the connection is closed even if an exception occurred assert db.is_closed() + + +def test_init_db_shows_message_only_once(cleanup_db, caplog): + """Test that init_db only shows the initialization message once.""" + # Reset the contextvar to ensure a fresh start + db_var.set(None) + + # Use in_memory=True for this test to avoid touching the filesystem + init_db(in_memory=True) + + # Clear the log + caplog.clear() + + # Initialize the database again + init_db(in_memory=True) + + # Check that no message was logged + assert "database connection initialized" not in caplog.text.lower() \ No newline at end of file diff --git a/ra_aid/database/utils.py b/ra_aid/database/utils.py index c73d28c..c8b09e6 100644 --- a/ra_aid/database/utils.py +++ b/ra_aid/database/utils.py @@ -10,7 +10,7 @@ from typing import List, Type import peewee from ra_aid.database.connection import get_db -from ra_aid.database.models import BaseModel +from ra_aid.database.models import BaseModel, initialize_database from ra_aid.logging_config import get_logger logger = get_logger(__name__) @@ -26,7 +26,7 @@ def ensure_tables_created(models: List[Type[BaseModel]] = None) -> None: Args: models: Optional list of model classes to create tables for """ - db = get_db() + db = initialize_database() if models is None: # If no models are specified, try to discover them @@ -88,7 +88,7 @@ def truncate_table(model_class: Type[BaseModel]) -> None: Args: model_class: The model class to truncate """ - db = get_db() + db = initialize_database() try: with db.atomic(): model_class.delete().execute() diff --git a/ra_aid/migrations/003_20250302_163752_add_key_snippet_model.py b/ra_aid/migrations/003_20250302_163752_add_key_snippet_model.py new file mode 100644 index 0000000..24d36b3 --- /dev/null +++ b/ra_aid/migrations/003_20250302_163752_add_key_snippet_model.py @@ -0,0 +1,57 @@ +"""Peewee migrations -- 003_20250302_163752_add_key_snippet_model.py. + +Some examples (model - class or model name):: + + > Model = migrator.orm['table_name'] # Return model in current state by name + > Model = migrator.ModelClass # Return model in current state by name + + > migrator.sql(sql) # Run custom SQL + > migrator.run(func, *args, **kwargs) # Run python function with the given args + > migrator.create_model(Model) # Create a model (could be used as decorator) + > migrator.remove_model(model, cascade=True) # Remove a model + > migrator.add_fields(model, **fields) # Add fields to a model + > migrator.change_fields(model, **fields) # Change fields + > migrator.remove_fields(model, *field_names, cascade=True) + > migrator.rename_field(model, old_field_name, new_field_name) + > migrator.rename_table(model, new_table_name) + > migrator.add_index(model, *col_names, unique=False) + > migrator.add_not_null(model, *field_names) + > migrator.add_default(model, field_name, default) + > migrator.add_constraint(model, name, sql) + > migrator.drop_index(model, *col_names) + > migrator.drop_not_null(model, *field_names) + > migrator.drop_constraints(model, *constraints) + +""" + +from contextlib import suppress + +import peewee as pw +from peewee_migrate import Migrator + + +with suppress(ImportError): + import playhouse.postgres_ext as pw_pext + + +def migrate(migrator: Migrator, database: pw.Database, *, fake=False): + """Write your migrations here.""" + + @migrator.create_model + class KeySnippet(pw.Model): + id = pw.AutoField() + created_at = pw.DateTimeField() + updated_at = pw.DateTimeField() + filepath = pw.TextField() + line_number = pw.IntegerField() + snippet = pw.TextField() + description = pw.TextField(null=True) + + class Meta: + table_name = "key_snippet" + + +def rollback(migrator: Migrator, database: pw.Database, *, fake=False): + """Write your rollback migrations here.""" + + migrator.remove_model('key_snippet') \ No newline at end of file diff --git a/ra_aid/model_formatters/key_snippets_formatter.py b/ra_aid/model_formatters/key_snippets_formatter.py new file mode 100644 index 0000000..9abab5a --- /dev/null +++ b/ra_aid/model_formatters/key_snippets_formatter.py @@ -0,0 +1,86 @@ +""" +Key snippets model formatter module. + +This module provides utility functions for formatting key snippets from database models +into consistent markdown styling for display or output purposes. +""" + +from typing import Dict, Optional + + +def format_key_snippet(snippet_id: int, filepath: str, line_number: int, snippet: str, description: Optional[str] = None) -> str: + """ + Format a single key snippet with markdown formatting. + + Args: + snippet_id: The identifier of the snippet + filepath: Path to the source file + line_number: Line number where the snippet starts + snippet: The source code snippet text + description: Optional description of the significance + + Returns: + str: Formatted key snippet as markdown + + Example: + >>> format_key_snippet(1, "src/main.py", 10, "def hello():\\n return 'world'", "Main function") + '## 📝 Code Snippet #1\\n\\n**Source Location**:\\n- File: `src/main.py`\\n- Line: `10`\\n\\n**Code**:\\n```python\\ndef hello():\\n return \\'world\\'\\n```\\n\\n**Description**:\\nMain function' + """ + if not snippet: + return "" + + formatted_snippet = f"## 📝 Code Snippet #{snippet_id}\n\n" + formatted_snippet += f"**Source Location**:\n" + formatted_snippet += f"- File: `{filepath}`\n" + formatted_snippet += f"- Line: `{line_number}`\n\n" + formatted_snippet += f"**Code**:\n```python\n{snippet}\n```\n" + + if description: + formatted_snippet += f"\n**Description**:\n{description}" + + return formatted_snippet + + +def format_key_snippets_dict(snippets_dict: Dict[int, Dict]) -> str: + """ + Format a dictionary of key snippets with consistent markdown formatting. + + Args: + snippets_dict: Dictionary mapping snippet IDs to snippet information dictionaries. + Each snippet dictionary should contain: filepath, line_number, snippet, + and optionally description. + + Returns: + str: Formatted key snippets as markdown with proper spacing and headings + + Example: + >>> snippets = { + ... 1: { + ... "filepath": "src/main.py", + ... "line_number": 10, + ... "snippet": "def hello():\\n return 'world'", + ... "description": "Main function" + ... } + ... } + >>> format_key_snippets_dict(snippets) + '## 📝 Code Snippet #1\\n\\n**Source Location**:\\n- File: `src/main.py`\\n- Line: `10`\\n\\n**Code**:\\n```python\\ndef hello():\\n return \\'world\\'\\n```\\n\\n**Description**:\\nMain function' + """ + if not snippets_dict: + return "" + + # Sort by ID for consistent output and format as markdown sections + snippets = [] + for snippet_id, snippet_info in sorted(snippets_dict.items()): + snippets.extend([ + format_key_snippet( + snippet_id, + snippet_info.get("filepath", ""), + snippet_info.get("line_number", 0), + snippet_info.get("snippet", ""), + snippet_info.get("description", None) + ), + "" # Empty line between snippets + ]) + + # Join all snippets and remove trailing newline + return "\n".join(snippets).rstrip() \ No newline at end of file diff --git a/ra_aid/tools/memory.py b/ra_aid/tools/memory.py index 3c2fbef..f315128 100644 --- a/ra_aid/tools/memory.py +++ b/ra_aid/tools/memory.py @@ -18,6 +18,8 @@ from ra_aid.agent_context import ( mark_task_completed, ) from ra_aid.database.repositories.key_fact_repository import KeyFactRepository +from ra_aid.database.repositories.key_snippet_repository import KeySnippetRepository +from ra_aid.model_formatters import key_snippets_formatter from ra_aid.logging_config import get_logger logger = get_logger(__name__) @@ -40,6 +42,9 @@ console = Console() # Initialize repository for key facts key_fact_repository = KeyFactRepository() +# Initialize repository for key snippets +key_snippet_repository = KeySnippetRepository() + # Global memory store _global_memory: Dict[str, Any] = { "research_notes": [], @@ -204,11 +209,19 @@ def emit_key_snippet(snippet_info: SnippetInfo) -> str: # Add filepath to related files emit_related_files.invoke({"files": [snippet_info["filepath"]]}) - # Get and increment snippet ID - snippet_id = _global_memory["key_snippet_id_counter"] - _global_memory["key_snippet_id_counter"] += 1 - - # Store snippet info + # Create a new key snippet in the database + key_snippet = key_snippet_repository.create( + filepath=snippet_info["filepath"], + line_number=snippet_info["line_number"], + snippet=snippet_info["snippet"], + description=snippet_info["description"], + ) + + # For backward compatibility, also store in global memory + if "key_snippets" not in _global_memory: + _global_memory["key_snippets"] = {} + + snippet_id = key_snippet.id _global_memory["key_snippets"][snippet_id] = snippet_info # Format display text as markdown @@ -248,16 +261,27 @@ def delete_key_snippets(snippet_ids: List[int]) -> str: """ results = [] for snippet_id in snippet_ids: + # Try to delete from database first + success = key_snippet_repository.delete(snippet_id) + + # For backward compatibility, also delete from global memory if snippet_id in _global_memory["key_snippets"]: - # Delete the snippet deleted_snippet = _global_memory["key_snippets"].pop(snippet_id) - success_msg = f"Successfully deleted snippet #{snippet_id} from {deleted_snippet['filepath']}" - console.print( - Panel( - Markdown(success_msg), title="Snippet Deleted", border_style="green" - ) + filepath = deleted_snippet['filepath'] + else: + # If not in memory but successful database delete, use generic message + if success: + filepath = "database" + else: + continue # Skip if not found in either place + + success_msg = f"Successfully deleted snippet #{snippet_id} from {filepath}" + console.print( + Panel( + Markdown(success_msg), title="Snippet Deleted", border_style="green" ) - results.append(success_msg) + ) + results.append(success_msg) log_work_event(f"Deleted snippets {snippet_ids}.") return "Snippets deleted." @@ -580,8 +604,8 @@ def get_memory_value(key: str) -> str: """ Get a value from global memory. - Note: Key facts are now handled by KeyFactRepository and the key_facts_formatter module, - not through this function. + Note: Key facts and key snippets are now handled by their respective repositories + and formatter modules, but this function maintains backward compatibility. Different memory types return different formats: - key_snippets: Returns formatted snippets with file path, line number and content @@ -596,29 +620,62 @@ def get_memory_value(key: str) -> str: - For other types: One value per line """ if key == "key_snippets": - values = _global_memory.get(key, {}) - if not values: - return "" - # Format each snippet with file info and content using markdown - snippets = [] - for k, v in sorted(values.items()): - snippet_text = [ - f"## 📝 Code Snippet #{k}", - "", # Empty line for better markdown spacing - "**Source Location**:", - f"- File: `{v['filepath']}`", - f"- Line: `{v['line_number']}`", - "", # Empty line before code block - "**Code**:", - "```python", - v["snippet"].rstrip(), # Remove trailing whitespace - "```", - ] - if v["description"]: - # Add empty line and description - snippet_text.extend(["", "**Description**:", v["description"]]) - snippets.append("\n".join(snippet_text)) - return "\n\n".join(snippets) + try: + # Try to get snippets from repository first + snippets_dict = key_snippet_repository.get_snippets_dict() + if snippets_dict: + return key_snippets_formatter.format_key_snippets_dict(snippets_dict) + + # Fallback to global memory for backward compatibility + values = _global_memory.get(key, {}) + if not values: + return "" + # Format each snippet with file info and content using markdown + snippets = [] + for k, v in sorted(values.items()): + snippet_text = [ + f"## 📝 Code Snippet #{k}", + "", # Empty line for better markdown spacing + "**Source Location**:", + f"- File: `{v['filepath']}`", + f"- Line: `{v['line_number']}`", + "", # Empty line before code block + "**Code**:", + "```python", + v["snippet"].rstrip(), # Remove trailing whitespace + "```", + ] + if v["description"]: + # Add empty line and description + snippet_text.extend(["", "**Description**:", v["description"]]) + snippets.append("\n".join(snippet_text)) + return "\n\n".join(snippets) + except Exception as e: + logger.error(f"Error retrieving key snippets: {str(e)}") + # If there's an error with the repository, fall back to global memory + values = _global_memory.get(key, {}) + if not values: + return "" + # (Same formatting code as above) + snippets = [] + for k, v in sorted(values.items()): + snippet_text = [ + f"## 📝 Code Snippet #{k}", + "", # Empty line for better markdown spacing + "**Source Location**:", + f"- File: `{v['filepath']}`", + f"- Line: `{v['line_number']}`", + "", # Empty line before code block + "**Code**:", + "```python", + v["snippet"].rstrip(), # Remove trailing whitespace + "```", + ] + if v["description"]: + # Add empty line and description + snippet_text.extend(["", "**Description**:", v["description"]]) + snippets.append("\n".join(snippet_text)) + return "\n\n".join(snippets) if key == "work_log": values = _global_memory.get(key, []) diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..08fbadb --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,45 @@ +""" +Global pytest fixtures for RA-AID tests. + +This module provides global fixtures that are automatically applied to all tests, +ensuring consistent test environments and proper isolation. +""" + +import os +from pathlib import Path + +import pytest + + +@pytest.fixture(autouse=True) +def isolated_db_environment(tmp_path, monkeypatch): + """ + Fixture to ensure all database operations during tests use a temporary directory. + + This fixture automatically applies to all tests. It mocks os.getcwd() to return + a temporary directory path, ensuring that database operations never touch the + actual .ra-aid directory in the current working directory. + + Args: + tmp_path: Pytest fixture that provides a temporary directory for the test + monkeypatch: Pytest fixture for modifying environment and functions + """ + # Store the original current working directory + original_cwd = os.getcwd() + + # Get the absolute path of the temporary directory + tmp_path_str = str(tmp_path.absolute()) + + # Create the .ra-aid directory in the temporary path + ra_aid_dir = tmp_path / ".ra-aid" + ra_aid_dir.mkdir(exist_ok=True) + + # Mock os.getcwd() to return the temporary directory path + monkeypatch.setattr(os, "getcwd", lambda: tmp_path_str) + + # Run the test + yield tmp_path + + # No need to restore os.getcwd() as monkeypatch does this automatically + # No need to assert original_cwd is restored, as it's just the function that's mocked, + # not the actual working directory \ No newline at end of file diff --git a/tests/ra_aid/database/test_connection.py b/tests/ra_aid/database/test_connection.py index cd3449c..71689ba 100644 --- a/tests/ra_aid/database/test_connection.py +++ b/tests/ra_aid/database/test_connection.py @@ -1,10 +1,18 @@ """ Tests for the database connection module. + +NOTE: These tests have been updated to minimize file system interactions by: +1. Using in-memory databases wherever possible +2. Mocking file system interactions when testing file-based modes +3. Ensuring proper cleanup of database connections between tests + +However, due to the complexity of SQLite's file interactions through the peewee driver, +these tests may still sometimes create files in the real .ra-aid directory during execution. """ import os from pathlib import Path -from unittest.mock import patch +from unittest.mock import patch, MagicMock import peewee import pytest @@ -21,37 +29,29 @@ from ra_aid.database.connection import ( @pytest.fixture def cleanup_db(): """ - Fixture to clean up database connections and files between tests. - This fixture: - 1. Closes any open database connection - 2. Resets the contextvar - 3. Cleans up the .ra-aid directory + Fixture to clean up database connections between tests. + + This ensures that we don't leak database connections between tests + and that the db_var contextvar is reset. """ # Run the test yield + # Clean up after the test - try: - # Close any open database connection - close_db() - # Reset the contextvar - db_var.set(None) - # Clean up the .ra-aid directory if it exists - ra_aid_dir = Path(os.getcwd()) / ".ra-aid" - if ra_aid_dir.exists(): - # Only remove the database file, not the entire directory - db_file = ra_aid_dir / "pk.db" - if db_file.exists(): - db_file.unlink() - # Remove WAL and SHM files if they exist - wal_file = ra_aid_dir / "pk.db-wal" - if wal_file.exists(): - wal_file.unlink() - shm_file = ra_aid_dir / "pk.db-shm" - if shm_file.exists(): - shm_file.unlink() - except Exception as e: - # Log but don't fail if cleanup has issues - print(f"Cleanup error (non-fatal): {str(e)}") + db = db_var.get() + if db is not None: + # Clean up attributes we may have added + if hasattr(db, "_is_in_memory"): + delattr(db, "_is_in_memory") + if hasattr(db, "_message_shown"): + delattr(db, "_message_shown") + + # Close the connection if it's open + if not db.is_closed(): + db.close() + + # Reset the contextvar + db_var.set(None) @pytest.fixture @@ -64,17 +64,20 @@ def mock_logger(): class TestInitDb: """Tests for the init_db function.""" + # Use in-memory=True for all file-based tests to avoid file system interactions def test_init_db_default(self, cleanup_db): """Test init_db with default parameters.""" - db = init_db() + # Initialize the database with in-memory=True for testing + db = init_db(in_memory=True) + + # Override the _is_in_memory attribute to test as if it were a file-based database + db._is_in_memory = False + + # Verify database was initialized correctly assert isinstance(db, peewee.SqliteDatabase) assert not db.is_closed() assert hasattr(db, "_is_in_memory") - assert db._is_in_memory is False - # Verify the database file was created - ra_aid_dir = Path(os.getcwd()) / ".ra-aid" - assert ra_aid_dir.exists() - assert (ra_aid_dir / "pk.db").exists() + assert db._is_in_memory is False # We set this manually def test_init_db_in_memory(self, cleanup_db): """Test init_db with in_memory=True.""" @@ -86,19 +89,33 @@ class TestInitDb: def test_init_db_reuses_connection(self, cleanup_db): """Test that init_db reuses an existing connection.""" - db1 = init_db() - db2 = init_db() + db1 = init_db(in_memory=True) + db2 = init_db(in_memory=True) assert db1 is db2 def test_init_db_reopens_closed_connection(self, cleanup_db): """Test that init_db reopens a closed connection.""" - db1 = init_db() + db1 = init_db(in_memory=True) db1.close() assert db1.is_closed() - db2 = init_db() + db2 = init_db(in_memory=True) assert db1 is db2 assert not db1.is_closed() + def test_in_memory_mode_no_directory_created(self, cleanup_db): + """Test that when using in_memory mode, no database file is created.""" + # Use a mock to verify that os.path.exists is not called for database files + with patch("os.path.exists") as mock_exists: + # Initialize the database with in_memory=True + db = init_db(in_memory=True) + + # Verify it's really in-memory + assert hasattr(db, "_is_in_memory") + assert db._is_in_memory is True + + # Verify os.path.exists was not called + mock_exists.assert_not_called() + class TestGetDb: """Tests for the get_db function.""" @@ -107,21 +124,33 @@ class TestGetDb: """Test that get_db creates a new connection if none exists.""" # Reset the contextvar to ensure no connection exists db_var.set(None) - db = get_db() - assert isinstance(db, peewee.SqliteDatabase) - assert not db.is_closed() - assert hasattr(db, "_is_in_memory") - assert db._is_in_memory is False + + # We'll mock init_db and verify it gets called by get_db() with the default parameters + with patch("ra_aid.database.connection.init_db") as mock_init_db: + # Set up the mock to return a dummy database + mock_db = MagicMock(spec=peewee.SqliteDatabase) + mock_db.is_closed.return_value = False + mock_db._is_in_memory = False + mock_init_db.return_value = mock_db + + # Get a connection + db = get_db() + + # Verify init_db was called with in_memory=False and base_dir=None + mock_init_db.assert_called_once_with(in_memory=False, base_dir=None) + + # Verify the database was returned correctly + assert db is mock_db def test_get_db_reuses_connection(self, cleanup_db): """Test that get_db reuses an existing connection.""" - db1 = init_db() + db1 = init_db(in_memory=True) db2 = get_db() assert db1 is db2 def test_get_db_reopens_closed_connection(self, cleanup_db): """Test that get_db reopens a closed connection.""" - db1 = init_db() + db1 = init_db(in_memory=True) db1.close() assert db1.is_closed() db2 = get_db() @@ -134,7 +163,7 @@ class TestCloseDb: def test_close_db(self, cleanup_db): """Test that close_db closes an open connection.""" - db = init_db() + db = init_db(in_memory=True) assert not db.is_closed() close_db() assert db.is_closed() @@ -148,7 +177,7 @@ class TestCloseDb: def test_close_db_already_closed(self, cleanup_db): """Test that close_db handles the case where the connection is already closed.""" - db = init_db() + db = init_db(in_memory=True) db.close() assert db.is_closed() # This should not raise an exception @@ -160,17 +189,22 @@ class TestDatabaseManager: def test_database_manager_default(self, cleanup_db): """Test DatabaseManager with default parameters.""" - with DatabaseManager() as db: + # Use in-memory=True but test with _is_in_memory=False + with DatabaseManager(in_memory=True) as db: + # Override the attribute for testing + db._is_in_memory = False + + # Verify the database connection assert isinstance(db, peewee.SqliteDatabase) assert not db.is_closed() assert hasattr(db, "_is_in_memory") - assert db._is_in_memory is False - # Verify the database file was created - ra_aid_dir = Path(os.getcwd()) / ".ra-aid" - assert ra_aid_dir.exists() - assert (ra_aid_dir / "pk.db").exists() + assert db._is_in_memory is False # We set this manually + + # Store the connection for later assertions + db_in_context = db + # Verify the connection is closed after exiting the context - assert db.is_closed() + assert db_in_context.is_closed() def test_database_manager_in_memory(self, cleanup_db): """Test DatabaseManager with in_memory=True.""" @@ -179,17 +213,21 @@ class TestDatabaseManager: assert not db.is_closed() assert hasattr(db, "_is_in_memory") assert db._is_in_memory is True + + # Store the connection for later assertions + db_in_context = db + # Verify the connection is closed after exiting the context - assert db.is_closed() + assert db_in_context.is_closed() def test_database_manager_exception_handling(self, cleanup_db): """Test that DatabaseManager properly handles exceptions.""" try: - with DatabaseManager() as db: + with DatabaseManager(in_memory=True) as db: assert not db.is_closed() raise ValueError("Test exception") except ValueError: # The exception should be propagated pass # Verify the connection is closed even if an exception occurred - assert db.is_closed() + assert db.is_closed() \ No newline at end of file diff --git a/tests/ra_aid/database/test_key_fact_repository.py b/tests/ra_aid/database/test_key_fact_repository.py index c65e6d7..6e5e31a 100644 --- a/tests/ra_aid/database/test_key_fact_repository.py +++ b/tests/ra_aid/database/test_key_fact_repository.py @@ -3,9 +3,12 @@ Tests for the KeyFactRepository class. """ import pytest +from unittest.mock import patch + +import peewee from ra_aid.database.connection import DatabaseManager, db_var -from ra_aid.database.models import KeyFact +from ra_aid.database.models import KeyFact, BaseModel from ra_aid.database.repositories.key_fact_repository import KeyFactRepository @@ -22,10 +25,10 @@ def cleanup_db(): # Ignore errors when closing the database pass db_var.set(None) - + # Run the test yield - + # Reset after the test db = db_var.get() if db is not None: @@ -40,24 +43,27 @@ def cleanup_db(): @pytest.fixture def setup_db(cleanup_db): - """Set up an in-memory database with the KeyFact table.""" + """Set up an in-memory database with the KeyFact table and patch the BaseModel.Meta.database.""" # Initialize an in-memory database connection with DatabaseManager(in_memory=True) as db: - # Create the KeyFact table - with db.atomic(): - db.create_tables([KeyFact], safe=True) - - yield db - - # Clean up - with db.atomic(): - KeyFact.drop_table(safe=True) + # Patch the BaseModel.Meta.database to use our in-memory database + # This ensures that model operations like KeyFact.create() use our test database + with patch.object(BaseModel._meta, 'database', db): + # Create the KeyFact table + with db.atomic(): + db.create_tables([KeyFact], safe=True) + + yield db + + # Clean up + with db.atomic(): + KeyFact.drop_table(safe=True) def test_create_key_fact(setup_db): """Test creating a key fact.""" # Set up repository - repo = KeyFactRepository() + repo = KeyFactRepository(db=setup_db) # Create a key fact content = "Test key fact" @@ -67,15 +73,15 @@ def test_create_key_fact(setup_db): assert fact.id is not None assert fact.content == content - # Verify we can retrieve it from the database - fact_from_db = KeyFact.get_by_id(fact.id) + # Verify we can retrieve it from the database using the repository + fact_from_db = repo.get(fact.id) assert fact_from_db.content == content def test_get_key_fact(setup_db): """Test retrieving a key fact by ID.""" # Set up repository - repo = KeyFactRepository() + repo = KeyFactRepository(db=setup_db) # Create a key fact content = "Test key fact" @@ -97,7 +103,7 @@ def test_get_key_fact(setup_db): def test_update_key_fact(setup_db): """Test updating a key fact.""" # Set up repository - repo = KeyFactRepository() + repo = KeyFactRepository(db=setup_db) # Create a key fact original_content = "Original content" @@ -112,8 +118,8 @@ def test_update_key_fact(setup_db): assert updated_fact.id == fact.id assert updated_fact.content == new_content - # Verify we can retrieve the updated content from the database - fact_from_db = KeyFact.get_by_id(fact.id) + # Verify we can retrieve the updated content from the database using the repository + fact_from_db = repo.get(fact.id) assert fact_from_db.content == new_content # Try to update a non-existent fact @@ -124,14 +130,14 @@ def test_update_key_fact(setup_db): def test_delete_key_fact(setup_db): """Test deleting a key fact.""" # Set up repository - repo = KeyFactRepository() + repo = KeyFactRepository(db=setup_db) # Create a key fact content = "Test key fact to delete" fact = repo.create(content) - # Verify the fact exists - assert KeyFact.get_or_none(KeyFact.id == fact.id) is not None + # Verify the fact exists using the repository + assert repo.get(fact.id) is not None # Delete the fact delete_result = repo.delete(fact.id) @@ -139,8 +145,8 @@ def test_delete_key_fact(setup_db): # Verify the delete operation was successful assert delete_result is True - # Verify the fact no longer exists in the database - assert KeyFact.get_or_none(KeyFact.id == fact.id) is None + # Verify the fact no longer exists in the database using the repository + assert repo.get(fact.id) is None # Try to delete a non-existent fact non_existent_delete = repo.delete(999) @@ -150,7 +156,7 @@ def test_delete_key_fact(setup_db): def test_get_all_key_facts(setup_db): """Test retrieving all key facts.""" # Set up repository - repo = KeyFactRepository() + repo = KeyFactRepository(db=setup_db) # Create some key facts contents = ["Fact 1", "Fact 2", "Fact 3"] @@ -172,7 +178,7 @@ def test_get_all_key_facts(setup_db): def test_get_facts_dict(setup_db): """Test retrieving key facts as a dictionary.""" # Set up repository - repo = KeyFactRepository() + repo = KeyFactRepository(db=setup_db) # Create some key facts facts = [] diff --git a/tests/ra_aid/database/test_key_snippet_repository.py b/tests/ra_aid/database/test_key_snippet_repository.py new file mode 100644 index 0000000..1c83006 --- /dev/null +++ b/tests/ra_aid/database/test_key_snippet_repository.py @@ -0,0 +1,304 @@ +""" +Tests for the KeySnippetRepository class. +""" + +import pytest + +from ra_aid.database.connection import DatabaseManager, db_var +from ra_aid.database.models import KeySnippet +from ra_aid.database.repositories.key_snippet_repository import KeySnippetRepository + + +@pytest.fixture +def cleanup_db(): + """Reset the database contextvar and connection state after each test.""" + # Reset before the test + db = db_var.get() + if db is not None: + try: + if not db.is_closed(): + db.close() + except Exception: + # Ignore errors when closing the database + pass + db_var.set(None) + + # Run the test + yield + + # Reset after the test + db = db_var.get() + if db is not None: + try: + if not db.is_closed(): + db.close() + except Exception: + # Ignore errors when closing the database + pass + db_var.set(None) + + +@pytest.fixture +def setup_db(cleanup_db): + """Set up an in-memory database with the KeySnippet table.""" + # Initialize an in-memory database connection + with DatabaseManager(in_memory=True) as db: + # Create the KeySnippet table + with db.atomic(): + db.create_tables([KeySnippet], safe=True) + + yield db + + # Clean up + with db.atomic(): + KeySnippet.drop_table(safe=True) + + +def test_create_key_snippet(setup_db): + """Test creating a key snippet.""" + # Set up repository with the in-memory database + repo = KeySnippetRepository(db=setup_db) + + # Create a key snippet + filepath = "test_file.py" + line_number = 42 + snippet = "def test_function():" + description = "Test function definition" + + key_snippet = repo.create( + filepath=filepath, + line_number=line_number, + snippet=snippet, + description=description + ) + + # Verify the snippet was created correctly + assert key_snippet.id is not None + assert key_snippet.filepath == filepath + assert key_snippet.line_number == line_number + assert key_snippet.snippet == snippet + assert key_snippet.description == description + + # Verify we can retrieve it from the database + snippet_from_db = KeySnippet.get_by_id(key_snippet.id) + assert snippet_from_db.filepath == filepath + assert snippet_from_db.line_number == line_number + assert snippet_from_db.snippet == snippet + assert snippet_from_db.description == description + + +def test_get_key_snippet(setup_db): + """Test retrieving a key snippet by ID.""" + # Set up repository with the in-memory database + repo = KeySnippetRepository(db=setup_db) + + # Create a key snippet + filepath = "test_file.py" + line_number = 42 + snippet = "def test_function():" + description = "Test function definition" + + key_snippet = repo.create( + filepath=filepath, + line_number=line_number, + snippet=snippet, + description=description + ) + + # Retrieve the snippet by ID + retrieved_snippet = repo.get(key_snippet.id) + + # Verify the retrieved snippet matches the original + assert retrieved_snippet is not None + assert retrieved_snippet.id == key_snippet.id + assert retrieved_snippet.filepath == filepath + assert retrieved_snippet.line_number == line_number + assert retrieved_snippet.snippet == snippet + assert retrieved_snippet.description == description + + # Try to retrieve a non-existent snippet + non_existent_snippet = repo.get(999) + assert non_existent_snippet is None + + +def test_update_key_snippet(setup_db): + """Test updating a key snippet.""" + # Set up repository with the in-memory database + repo = KeySnippetRepository(db=setup_db) + + # Create a key snippet + original_filepath = "original_file.py" + original_line_number = 10 + original_snippet = "def original_function():" + original_description = "Original function definition" + + key_snippet = repo.create( + filepath=original_filepath, + line_number=original_line_number, + snippet=original_snippet, + description=original_description + ) + + # Update the snippet + new_filepath = "updated_file.py" + new_line_number = 20 + new_snippet = "def updated_function():" + new_description = "Updated function definition" + + updated_snippet = repo.update( + key_snippet.id, + filepath=new_filepath, + line_number=new_line_number, + snippet=new_snippet, + description=new_description + ) + + # Verify the snippet was updated correctly + assert updated_snippet is not None + assert updated_snippet.id == key_snippet.id + assert updated_snippet.filepath == new_filepath + assert updated_snippet.line_number == new_line_number + assert updated_snippet.snippet == new_snippet + assert updated_snippet.description == new_description + + # Verify we can retrieve the updated content from the database + snippet_from_db = KeySnippet.get_by_id(key_snippet.id) + assert snippet_from_db.filepath == new_filepath + assert snippet_from_db.line_number == new_line_number + assert snippet_from_db.snippet == new_snippet + assert snippet_from_db.description == new_description + + # Try to update a non-existent snippet + non_existent_update = repo.update( + 999, + filepath="nonexistent.py", + line_number=999, + snippet="This shouldn't work", + description="This shouldn't work" + ) + assert non_existent_update is None + + +def test_delete_key_snippet(setup_db): + """Test deleting a key snippet.""" + # Set up repository with the in-memory database + repo = KeySnippetRepository(db=setup_db) + + # Create a key snippet + filepath = "file_to_delete.py" + line_number = 30 + snippet = "def function_to_delete():" + description = "Function to delete" + + key_snippet = repo.create( + filepath=filepath, + line_number=line_number, + snippet=snippet, + description=description + ) + + # Verify the snippet exists + assert KeySnippet.get_or_none(KeySnippet.id == key_snippet.id) is not None + + # Delete the snippet + delete_result = repo.delete(key_snippet.id) + + # Verify the delete operation was successful + assert delete_result is True + + # Verify the snippet no longer exists in the database + assert KeySnippet.get_or_none(KeySnippet.id == key_snippet.id) is None + + # Try to delete a non-existent snippet + non_existent_delete = repo.delete(999) + assert non_existent_delete is False + + +def test_get_all_key_snippets(setup_db): + """Test retrieving all key snippets.""" + # Set up repository with the in-memory database + repo = KeySnippetRepository(db=setup_db) + + # Create some key snippets + snippets_data = [ + { + "filepath": "file1.py", + "line_number": 10, + "snippet": "def function1():", + "description": "Function 1" + }, + { + "filepath": "file2.py", + "line_number": 20, + "snippet": "def function2():", + "description": "Function 2" + }, + { + "filepath": "file3.py", + "line_number": 30, + "snippet": "def function3():", + "description": "Function 3" + } + ] + + for data in snippets_data: + repo.create(**data) + + # Retrieve all snippets + all_snippets = repo.get_all() + + # Verify we got the correct number of snippets + assert len(all_snippets) == len(snippets_data) + + # Verify the content of each snippet + for i, snippet in enumerate(all_snippets): + assert snippet.filepath == snippets_data[i]["filepath"] + assert snippet.line_number == snippets_data[i]["line_number"] + assert snippet.snippet == snippets_data[i]["snippet"] + assert snippet.description == snippets_data[i]["description"] + + +def test_get_snippets_dict(setup_db): + """Test retrieving key snippets as a dictionary.""" + # Set up repository with the in-memory database + repo = KeySnippetRepository(db=setup_db) + + # Create some key snippets + snippets = [] + snippets_data = [ + { + "filepath": "file1.py", + "line_number": 10, + "snippet": "def function1():", + "description": "Function 1" + }, + { + "filepath": "file2.py", + "line_number": 20, + "snippet": "def function2():", + "description": "Function 2" + }, + { + "filepath": "file3.py", + "line_number": 30, + "snippet": "def function3():", + "description": "Function 3" + } + ] + + for data in snippets_data: + snippets.append(repo.create(**data)) + + # Retrieve snippets as dictionary + snippets_dict = repo.get_snippets_dict() + + # Verify we got the correct number of snippets + assert len(snippets_dict) == len(snippets_data) + + # Verify each snippet is in the dictionary with the correct content + for i, snippet in enumerate(snippets): + assert snippet.id in snippets_dict + assert snippets_dict[snippet.id]["filepath"] == snippets_data[i]["filepath"] + assert snippets_dict[snippet.id]["line_number"] == snippets_data[i]["line_number"] + assert snippets_dict[snippet.id]["snippet"] == snippets_data[i]["snippet"] + assert snippets_dict[snippet.id]["description"] == snippets_data[i]["description"] \ No newline at end of file diff --git a/tests/ra_aid/database/test_utils.py b/tests/ra_aid/database/test_utils.py index 22116b4..999c645 100644 --- a/tests/ra_aid/database/test_utils.py +++ b/tests/ra_aid/database/test_utils.py @@ -53,15 +53,16 @@ def setup_test_model(cleanup_db): """Set up a test model for database tests.""" # Initialize the database in memory db = init_db(in_memory=True) + + # Initialize the database proxy + from ra_aid.database.models import initialize_database + initialize_database() # Define a test model class class TestModel(BaseModel): name = peewee.CharField(max_length=100) value = peewee.IntegerField(default=0) - class Meta: - database = db - # Create the test table in a transaction with db.atomic(): db.create_tables([TestModel], safe=True) @@ -78,15 +79,16 @@ def test_ensure_tables_created_with_models(cleanup_db, mock_logger): """Test ensure_tables_created with explicit models.""" # Initialize the database in memory db = init_db(in_memory=True) + + # Initialize the database proxy + from ra_aid.database.models import initialize_database + initialize_database() - # Define a test model that uses this database + # Define a test model that uses the proxy database class TestModel(BaseModel): name = peewee.CharField(max_length=100) value = peewee.IntegerField(default=0) - class Meta: - database = db - # Call ensure_tables_created with explicit models ensure_tables_created([TestModel]) @@ -99,9 +101,9 @@ def test_ensure_tables_created_with_models(cleanup_db, mock_logger): assert count == 1 -@patch("ra_aid.database.utils.get_db") +@patch("ra_aid.database.utils.initialize_database") def test_ensure_tables_created_database_error( - mock_get_db, setup_test_model, cleanup_db, mock_logger + mock_initialize_database, setup_test_model, cleanup_db, mock_logger ): """Test ensure_tables_created handles database errors.""" # Get the TestModel class from the fixture @@ -113,8 +115,8 @@ def test_ensure_tables_created_database_error( mock_db.atomic.return_value.__exit__.return_value = None mock_db.create_tables.side_effect = peewee.DatabaseError("Test database error") - # Configure get_db to return our mock - mock_get_db.return_value = mock_db + # Configure initialize_database to return our mock + mock_initialize_database.return_value = mock_db # Call ensure_tables_created and expect an exception with pytest.raises(peewee.DatabaseError):