"""
Tests for the KeySnippetRepository class.
"""

import contextlib

import pytest

from ra_aid.database.connection import DatabaseManager, db_var
from ra_aid.database.models import KeySnippet
from ra_aid.database.pydantic_models import KeySnippetModel
from ra_aid.database.repositories.key_snippet_repository import KeySnippetRepository


def _reset_db_var():
    """Close the database held in ``db_var`` (best effort) and clear the contextvar."""
    db = db_var.get()
    if db is not None:
        # Closing is deliberately best-effort: an error while closing a stale
        # connection must not fail (or mask the result of) the test itself.
        with contextlib.suppress(Exception):
            if not db.is_closed():
                db.close()
    db_var.set(None)


def _sample_snippets_data():
    """Return a fresh list of keyword-argument dicts describing three key snippets.

    A new list is built on every call so no test can leak mutations into another.
    """
    return [
        {
            "filepath": "file1.py",
            "line_number": 10,
            "snippet": "def function1():",
            "description": "Function 1",
        },
        {
            "filepath": "file2.py",
            "line_number": 20,
            "snippet": "def function2():",
            "description": "Function 2",
        },
        {
            "filepath": "file3.py",
            "line_number": 30,
            "snippet": "def function3():",
            "description": "Function 3",
        },
    ]


def _assert_fields(obj, **expected):
    """Assert that every named attribute of *obj* equals its expected value.

    Works for both Peewee rows and Pydantic models since it only uses getattr.
    """
    for field, value in expected.items():
        actual = getattr(obj, field)
        assert actual == value, f"{field}: expected {value!r}, got {actual!r}"


@pytest.fixture
def cleanup_db():
    """Reset the database contextvar and connection state before and after each test."""
    _reset_db_var()
    yield
    _reset_db_var()


@pytest.fixture
def setup_db(cleanup_db):
    """Set up an in-memory database with the KeySnippet table.

    Yields the connection produced by ``DatabaseManager(in_memory=True)``; the
    KeySnippet table is dropped again during fixture teardown.
    """
    with DatabaseManager(in_memory=True) as db:
        with db.atomic():
            db.create_tables([KeySnippet], safe=True)

        yield db

        with db.atomic():
            KeySnippet.drop_table(safe=True)


def test_create_key_snippet(setup_db):
    """Test creating a key snippet."""
    repo = KeySnippetRepository(db=setup_db)
    fields = {
        "filepath": "test_file.py",
        "line_number": 42,
        "snippet": "def test_function():",
        "description": "Test function definition",
    }

    key_snippet = repo.create(**fields)

    assert key_snippet.id is not None
    assert isinstance(key_snippet, KeySnippetModel)
    _assert_fields(key_snippet, **fields)

    # The row must be persisted, not merely reflected in the returned model.
    _assert_fields(KeySnippet.get_by_id(key_snippet.id), **fields)


def test_get_key_snippet(setup_db):
    """Test retrieving a key snippet by ID."""
    repo = KeySnippetRepository(db=setup_db)
    fields = {
        "filepath": "test_file.py",
        "line_number": 42,
        "snippet": "def test_function():",
        "description": "Test function definition",
    }
    key_snippet = repo.create(**fields)

    retrieved_snippet = repo.get(key_snippet.id)

    assert retrieved_snippet is not None
    assert isinstance(retrieved_snippet, KeySnippetModel)
    assert retrieved_snippet.id == key_snippet.id
    _assert_fields(retrieved_snippet, **fields)

    # Only one row exists in this fresh in-memory table, so ID 999 should be absent.
    assert repo.get(999) is None


def test_update_key_snippet(setup_db):
    """Test updating a key snippet."""
    repo = KeySnippetRepository(db=setup_db)
    key_snippet = repo.create(
        filepath="original_file.py",
        line_number=10,
        snippet="def original_function():",
        description="Original function definition",
    )
    new_fields = {
        "filepath": "updated_file.py",
        "line_number": 20,
        "snippet": "def updated_function():",
        "description": "Updated function definition",
    }

    updated_snippet = repo.update(key_snippet.id, **new_fields)

    assert updated_snippet is not None
    assert isinstance(updated_snippet, KeySnippetModel)
    assert updated_snippet.id == key_snippet.id
    _assert_fields(updated_snippet, **new_fields)

    # The update must be written through to the database row.
    _assert_fields(KeySnippet.get_by_id(key_snippet.id), **new_fields)

    # Updating a snippet that does not exist reports failure with None.
    non_existent_update = repo.update(
        999,
        filepath="nonexistent.py",
        line_number=999,
        snippet="This shouldn't work",
        description="This shouldn't work",
    )
    assert non_existent_update is None


def test_delete_key_snippet(setup_db):
    """Test deleting a key snippet."""
    repo = KeySnippetRepository(db=setup_db)
    key_snippet = repo.create(
        filepath="file_to_delete.py",
        line_number=30,
        snippet="def function_to_delete():",
        description="Function to delete",
    )
    assert KeySnippet.get_or_none(KeySnippet.id == key_snippet.id) is not None

    assert repo.delete(key_snippet.id) is True
    assert KeySnippet.get_or_none(KeySnippet.id == key_snippet.id) is None

    # Deleting a snippet that does not exist reports failure with False.
    assert repo.delete(999) is False


def test_get_all_key_snippets(setup_db):
    """Test retrieving all key snippets."""
    repo = KeySnippetRepository(db=setup_db)
    snippets_data = _sample_snippets_data()
    created_ids = [repo.create(**data).id for data in snippets_data]

    all_snippets = repo.get_all()

    assert len(all_snippets) == len(snippets_data)
    assert all(isinstance(snippet, KeySnippetModel) for snippet in all_snippets)

    # Match rows by ID instead of position: the retrieval order of get_all()
    # is not something this test should depend on.
    snippets_by_id = {snippet.id: snippet for snippet in all_snippets}
    assert set(snippets_by_id) == set(created_ids)
    for snippet_id, data in zip(created_ids, snippets_data):
        _assert_fields(snippets_by_id[snippet_id], **data)


def test_get_snippets_dict(setup_db):
    """Test retrieving key snippets as a dictionary keyed by snippet ID."""
    repo = KeySnippetRepository(db=setup_db)
    snippets_data = _sample_snippets_data()
    snippets = [repo.create(**data) for data in snippets_data]

    snippets_dict = repo.get_snippets_dict()

    assert len(snippets_dict) == len(snippets_data)
    for snippet, data in zip(snippets, snippets_data):
        assert snippet.id in snippets_dict
        entry = snippets_dict[snippet.id]
        for field, value in data.items():
            assert entry[field] == value, f"{field}: expected {value!r}, got {entry[field]!r}"


def test_to_model_conversion(setup_db):
    """Test conversion from Peewee model to Pydantic model."""
    repo = KeySnippetRepository(db=setup_db)

    # Create the row through Peewee directly so only _to_model is under test.
    peewee_snippet = KeySnippet.create(
        filepath="conversion_test.py",
        line_number=100,
        snippet="def conversion_test():",
        description="Test model conversion",
    )

    pydantic_snippet = repo._to_model(peewee_snippet)

    assert isinstance(pydantic_snippet, KeySnippetModel)
    for field in ("id", "filepath", "line_number", "snippet", "description"):
        assert getattr(pydantic_snippet, field) == getattr(peewee_snippet, field), field

    # None passes straight through the conversion.
    assert repo._to_model(None) is None