From 772ce3e049cc7d2a03a4bb6e0c4ce490101b9f40 Mon Sep 17 00:00:00 2001 From: AI Christianson Date: Sun, 2 Mar 2025 19:33:03 -0500 Subject: [PATCH] only access snippets via repository --- ra_aid/tools/memory.py | 108 ++++--------------- tests/ra_aid/tools/test_memory.py | 174 ++++++++++++++++++++++-------- 2 files changed, 150 insertions(+), 132 deletions(-) diff --git a/ra_aid/tools/memory.py b/ra_aid/tools/memory.py index 817a77d..7c83ec4 100644 --- a/ra_aid/tools/memory.py +++ b/ra_aid/tools/memory.py @@ -53,8 +53,6 @@ _global_memory: Dict[str, Any] = { "task_id_counter": 1, # Counter for generating unique task IDs "key_facts": {}, # Dict[int, str] - ID to fact mapping (deprecated, using DB now) "key_fact_id_counter": 1, # Counter for generating unique fact IDs (deprecated, using DB now) - "key_snippets": {}, # Dict[int, SnippetInfo] - ID to snippet mapping - "key_snippet_id_counter": 1, # Counter for generating unique snippet IDs "implementation_requested": False, "related_files": {}, # Dict[int, str] - ID to filepath mapping "related_file_id_counter": 1, # Counter for generating unique file IDs @@ -190,7 +188,7 @@ def request_implementation() -> str: @tool("emit_key_snippet") def emit_key_snippet(snippet_info: SnippetInfo) -> str: - """Store a single source code snippet in global memory which represents key information. + """Store a single source code snippet in the database which represents key information. Automatically adds the filepath of the snippet to related files. This is for **existing**, or **just-written** files, not for things to be created in the future. @@ -217,14 +215,8 @@ def emit_key_snippet(snippet_info: SnippetInfo) -> str: description=snippet_info["description"], ) - # For backward compatibility, also store in global memory - if "key_snippets" not in _global_memory: - _global_memory["key_snippets"] = {} - - # Use id_counter for compatibility with tests - snippet_id = _global_memory["key_snippet_id_counter"] - _global_memory["key_snippet_id_counter"] += 1 - _global_memory["key_snippets"][snippet_id] = snippet_info + # Get the snippet ID from the database record + snippet_id = key_snippet.id # Format display text as markdown display_text = [ @@ -255,7 +247,7 @@ def emit_key_snippet(snippet_info: SnippetInfo) -> str: @tool("delete_key_snippets") def delete_key_snippets(snippet_ids: List[int]) -> str: - """Delete multiple key snippets from global memory by their IDs. + """Delete multiple key snippets from the database by their IDs. Silently skips any IDs that don't exist. Args: @@ -263,27 +255,20 @@ def delete_key_snippets(snippet_ids: List[int]) -> str: """ results = [] for snippet_id in snippet_ids: - # Try to delete from database first - success = key_snippet_repository.delete(snippet_id) - - # For backward compatibility, also delete from global memory - if snippet_id in _global_memory["key_snippets"]: - deleted_snippet = _global_memory["key_snippets"].pop(snippet_id) - filepath = deleted_snippet['filepath'] - else: - # If not in memory but successful database delete, use generic message + # Get the snippet first to capture filepath for the message + snippet = key_snippet_repository.get(snippet_id) + if snippet: + filepath = snippet.filepath + # Delete from database + success = key_snippet_repository.delete(snippet_id) if success: - filepath = "database" - else: - continue # Skip if not found in either place - - success_msg = f"Successfully deleted snippet #{snippet_id} from {filepath}" - console.print( - Panel( - Markdown(success_msg), title="Snippet Deleted", border_style="green" - ) - ) - results.append(success_msg) + success_msg = f"Successfully deleted snippet #{snippet_id} from {filepath}" + console.print( + Panel( + Markdown(success_msg), title="Snippet Deleted", border_style="green" + ) + ) + results.append(success_msg) log_work_event(f"Deleted snippets {snippet_ids}.") return "Snippets deleted." @@ -606,8 +591,8 @@ def get_memory_value(key: str) -> str: """ Get a value from global memory. - Note: Key facts and key snippets are now handled by their respective repositories - and formatter modules, but this function maintains backward compatibility. + Note: Key facts and key snippets are handled by their respective repositories + and formatter modules. Different memory types return different formats: - key_snippets: Returns formatted snippets with file path, line number and content @@ -623,61 +608,12 @@ def get_memory_value(key: str) -> str: """ if key == "key_snippets": try: - # Try to get snippets from repository first + # Get snippets from repository snippets_dict = key_snippet_repository.get_snippets_dict() - if snippets_dict: - return key_snippets_formatter.format_key_snippets_dict(snippets_dict) - - # Fallback to global memory for backward compatibility - values = _global_memory.get(key, {}) - if not values: - return "" - # Format each snippet with file info and content using markdown - snippets = [] - for k, v in sorted(values.items()): - snippet_text = [ - f"## 📝 Code Snippet #{k}", - "", # Empty line for better markdown spacing - "**Source Location**:", - f"- File: `{v['filepath']}`", - f"- Line: `{v['line_number']}`", - "", # Empty line before code block - "**Code**:", - "```python", - v["snippet"].rstrip(), # Remove trailing whitespace - "```", - ] - if v["description"]: - # Add empty line and description - snippet_text.extend(["", "**Description**:", v["description"]]) - snippets.append("\n".join(snippet_text)) - return "\n\n".join(snippets) + return key_snippets_formatter.format_key_snippets_dict(snippets_dict) except Exception as e: logger.error(f"Error retrieving key snippets: {str(e)}") - # If there's an error with the repository, fall back to global memory - values = _global_memory.get(key, {}) - if not values: - return "" - # (Same formatting code as above) - snippets = [] - for k, v in sorted(values.items()): - snippet_text = [ - f"## 📝 Code Snippet #{k}", - "", # Empty line for better markdown spacing - "**Source Location**:", - f"- File: `{v['filepath']}`", - f"- Line: `{v['line_number']}`", - "", # Empty line before code block - "**Code**:", - "```python", - v["snippet"].rstrip(), # Remove trailing whitespace - "```", - ] - if v["description"]: - # Add empty line and description - snippet_text.extend(["", "**Description**:", v["description"]]) - snippets.append("\n".join(snippet_text)) - return "\n\n".join(snippets) + return "" if key == "work_log": values = _global_memory.get(key, []) diff --git a/tests/ra_aid/tools/test_memory.py b/tests/ra_aid/tools/test_memory.py index db8cfbf..30e85be 100644 --- a/tests/ra_aid/tools/test_memory.py +++ b/tests/ra_aid/tools/test_memory.py @@ -17,6 +17,7 @@ from ra_aid.tools.memory import ( get_related_files, get_work_log, key_fact_repository, + key_snippet_repository, log_work_event, reset_work_log, swap_task_order, @@ -30,8 +31,6 @@ def reset_memory(): """Reset global memory before each test""" _global_memory["key_facts"] = {} _global_memory["key_fact_id_counter"] = 0 - _global_memory["key_snippets"] = {} - _global_memory["key_snippet_id_counter"] = 0 _global_memory["research_notes"] = [] _global_memory["plans"] = [] _global_memory["tasks"] = {} @@ -43,8 +42,6 @@ def reset_memory(): # Clean up after test _global_memory["key_facts"] = {} _global_memory["key_fact_id_counter"] = 0 - _global_memory["key_snippets"] = {} - _global_memory["key_snippet_id_counter"] = 0 _global_memory["research_notes"] = [] _global_memory["plans"] = [] _global_memory["tasks"] = {} @@ -113,6 +110,66 @@ def mock_repository(): yield mock_repo +@pytest.fixture(autouse=True) +def mock_key_snippet_repository(): + """Mock the KeySnippetRepository to avoid database operations during tests""" + with patch('ra_aid.tools.memory.key_snippet_repository') as mock_repo: + # Setup the mock repository to behave like the original, but using memory + snippets = {} # Local in-memory storage + snippet_id_counter = 0 + + # Mock KeySnippet objects + class MockKeySnippet: + def __init__(self, id, filepath, line_number, snippet, description=None): + self.id = id + self.filepath = filepath + self.line_number = line_number + self.snippet = snippet + self.description = description + + # Mock create method + def mock_create(filepath, line_number, snippet, description=None): + nonlocal snippet_id_counter + key_snippet = MockKeySnippet(snippet_id_counter, filepath, line_number, snippet, description) + snippets[snippet_id_counter] = key_snippet + snippet_id_counter += 1 + return key_snippet + mock_repo.create.side_effect = mock_create + + # Mock get method + def mock_get(snippet_id): + return snippets.get(snippet_id) + mock_repo.get.side_effect = mock_get + + # Mock delete method + def mock_delete(snippet_id): + if snippet_id in snippets: + del snippets[snippet_id] + return True + return False + mock_repo.delete.side_effect = mock_delete + + # Mock get_snippets_dict method + def mock_get_snippets_dict(): + return { + snippet_id: { + "filepath": snippet.filepath, + "line_number": snippet.line_number, + "snippet": snippet.snippet, + "description": snippet.description + } + for snippet_id, snippet in snippets.items() + } + mock_repo.get_snippets_dict.side_effect = mock_get_snippets_dict + + # Mock get_all method + def mock_get_all(): + return list(snippets.values()) + mock_repo.get_all.side_effect = mock_get_all + + yield mock_repo + + def test_emit_key_facts_single_fact(reset_memory, mock_repository): """Test emitting a single key fact using emit_key_facts""" # Test with single fact @@ -238,7 +295,7 @@ def test_emit_key_facts_triggers_cleaner(reset_memory, mock_repository): mock_repository.get_all.assert_called_once() -def test_emit_key_snippet(reset_memory): +def test_emit_key_snippet(reset_memory, mock_key_snippet_repository): """Test emitting a single code snippet""" # Test snippet with description snippet = { @@ -254,11 +311,13 @@ def test_emit_key_snippet(reset_memory): # Verify return message assert result == "Snippet #0 stored." - # Verify snippet stored correctly - assert _global_memory["key_snippets"][0] == snippet - - # Verify counter incremented correctly - assert _global_memory["key_snippet_id_counter"] == 1 + # Verify create was called correctly + mock_key_snippet_repository.create.assert_called_with( + filepath="test.py", + line_number=10, + snippet="def test():\n pass", + description="Test function" + ) # Test snippet without description snippet2 = { @@ -274,16 +333,18 @@ def test_emit_key_snippet(reset_memory): # Verify return message assert result == "Snippet #1 stored." - # Verify snippet stored correctly - assert _global_memory["key_snippets"][1] == snippet2 - - # Verify counter incremented correctly - assert _global_memory["key_snippet_id_counter"] == 2 + # Verify create was called correctly + mock_key_snippet_repository.create.assert_called_with( + filepath="main.py", + line_number=20, + snippet="print('hello')", + description=None + ) -def test_delete_key_snippets(reset_memory): +def test_delete_key_snippets(reset_memory, mock_key_snippet_repository): """Test deleting multiple code snippets""" - # Add test snippets + # Mock snippets snippets = [ { "filepath": "test1.py", @@ -308,20 +369,28 @@ def test_delete_key_snippets(reset_memory): for snippet in snippets: emit_key_snippet.invoke({"snippet_info": snippet}) + # Reset mock to clear call history + mock_key_snippet_repository.reset_mock() + # Test deleting mix of valid and invalid IDs result = delete_key_snippets.invoke({"snippet_ids": [0, 1, 999]}) # Verify success message assert result == "Snippets deleted." - # Verify correct snippets removed - assert 0 not in _global_memory["key_snippets"] - assert 1 not in _global_memory["key_snippets"] - assert 2 in _global_memory["key_snippets"] - assert _global_memory["key_snippets"][2]["filepath"] == "test3.py" + # Verify repository delete was called with correct IDs + mock_key_snippet_repository.get.assert_any_call(0) + mock_key_snippet_repository.get.assert_any_call(1) + mock_key_snippet_repository.get.assert_any_call(999) + + mock_key_snippet_repository.delete.assert_any_call(0) + mock_key_snippet_repository.delete.assert_any_call(1) + + # Make sure delete wasn't called for ID 999 + assert mock_key_snippet_repository.delete.call_count == 2 -def test_delete_key_snippets_empty(reset_memory): +def test_delete_key_snippets_empty(reset_memory, mock_key_snippet_repository): """Test deleting snippets with empty ID list""" # Add a test snippet snippet = { @@ -331,13 +400,16 @@ def test_delete_key_snippets_empty(reset_memory): "description": None, } emit_key_snippet.invoke({"snippet_info": snippet}) + + # Reset mock to clear call history + mock_key_snippet_repository.reset_mock() # Test with empty list result = delete_key_snippets.invoke({"snippet_ids": []}) assert result == "Snippets deleted." - # Verify snippet still exists - assert 0 in _global_memory["key_snippets"] + # Verify no call to delete method + mock_key_snippet_repository.delete.assert_not_called() def test_emit_related_files_basic(reset_memory, tmp_path): @@ -541,7 +613,7 @@ def test_emit_related_files_path_normalization(reset_memory, tmp_path): os.chdir(original_dir) -def test_key_snippets_integration(reset_memory, tmp_path): +def test_key_snippets_integration(reset_memory, tmp_path, mock_key_snippet_repository): """Integration test for key snippets functionality""" # Create test files file1 = tmp_path / "file1.py" @@ -577,7 +649,7 @@ def test_key_snippets_integration(reset_memory, tmp_path): for i, snippet in enumerate(snippets): result = emit_key_snippet.invoke({"snippet_info": snippet}) assert result == f"Snippet #{i} stored." - assert _global_memory["key_snippet_id_counter"] == 3 + # Verify related files were tracked with IDs assert len(_global_memory["related_files"]) == 3 # Check files are stored with proper IDs @@ -586,25 +658,25 @@ def test_key_snippets_integration(reset_memory, tmp_path): assert str(file2) in file_values assert str(file3) in file_values - # Verify all snippets were stored correctly - assert len(_global_memory["key_snippets"]) == 3 - assert _global_memory["key_snippets"][0] == snippets[0] - assert _global_memory["key_snippets"][1] == snippets[1] - assert _global_memory["key_snippets"][2] == snippets[2] + # Verify repository create was called for each snippet + assert mock_key_snippet_repository.create.call_count == 3 + + # Reset mock to clear call history + mock_key_snippet_repository.reset_mock() # Delete some but not all snippets (0 and 2) result = delete_key_snippets.invoke({"snippet_ids": [0, 2]}) assert result == "Snippets deleted." - # Verify remaining snippet is intact - assert len(_global_memory["key_snippets"]) == 1 - assert 1 in _global_memory["key_snippets"] - assert _global_memory["key_snippets"][1] == snippets[1] + # Verify delete was called for the correct IDs + mock_key_snippet_repository.delete.assert_any_call(0) + mock_key_snippet_repository.delete.assert_any_call(2) + assert mock_key_snippet_repository.delete.call_count == 2 + + # Reset mock again + mock_key_snippet_repository.reset_mock() - # Counter should remain unchanged after deletions - assert _global_memory["key_snippet_id_counter"] == 3 - - # Add new snippet to verify counter continues correctly + # Add new snippet file4 = tmp_path / "file4.py" file4.write_text("def func4():\n return False") new_snippet = { @@ -615,21 +687,31 @@ def test_key_snippets_integration(reset_memory, tmp_path): } result = emit_key_snippet.invoke({"snippet_info": new_snippet}) assert result == "Snippet #3 stored." - assert _global_memory["key_snippet_id_counter"] == 4 + + # Verify create was called with correct params + mock_key_snippet_repository.create.assert_called_with( + filepath=str(file4), + line_number=40, + snippet="def func4():\n return False", + description="Fourth function" + ) + # Verify new file was added to related files file_values = _global_memory["related_files"].values() assert str(file4) in file_values assert len(_global_memory["related_files"]) == 4 + + # Reset mock again + mock_key_snippet_repository.reset_mock() # Delete remaining snippets result = delete_key_snippets.invoke({"snippet_ids": [1, 3]}) assert result == "Snippets deleted." - # Verify all snippets are gone - assert len(_global_memory["key_snippets"]) == 0 - - # Counter should still maintain its value - assert _global_memory["key_snippet_id_counter"] == 4 + # Verify delete was called for the correct IDs + mock_key_snippet_repository.delete.assert_any_call(1) + mock_key_snippet_repository.delete.assert_any_call(3) + assert mock_key_snippet_repository.delete.call_count == 2 def test_emit_task_with_id(reset_memory):