only access snippets via repository

This commit is contained in:
AI Christianson 2025-03-02 19:33:03 -05:00
parent cb3504016f
commit 772ce3e049
2 changed files with 150 additions and 132 deletions

View File

@ -53,8 +53,6 @@ _global_memory: Dict[str, Any] = {
"task_id_counter": 1, # Counter for generating unique task IDs "task_id_counter": 1, # Counter for generating unique task IDs
"key_facts": {}, # Dict[int, str] - ID to fact mapping (deprecated, using DB now) "key_facts": {}, # Dict[int, str] - ID to fact mapping (deprecated, using DB now)
"key_fact_id_counter": 1, # Counter for generating unique fact IDs (deprecated, using DB now) "key_fact_id_counter": 1, # Counter for generating unique fact IDs (deprecated, using DB now)
"key_snippets": {}, # Dict[int, SnippetInfo] - ID to snippet mapping
"key_snippet_id_counter": 1, # Counter for generating unique snippet IDs
"implementation_requested": False, "implementation_requested": False,
"related_files": {}, # Dict[int, str] - ID to filepath mapping "related_files": {}, # Dict[int, str] - ID to filepath mapping
"related_file_id_counter": 1, # Counter for generating unique file IDs "related_file_id_counter": 1, # Counter for generating unique file IDs
@ -190,7 +188,7 @@ def request_implementation() -> str:
@tool("emit_key_snippet") @tool("emit_key_snippet")
def emit_key_snippet(snippet_info: SnippetInfo) -> str: def emit_key_snippet(snippet_info: SnippetInfo) -> str:
"""Store a single source code snippet in global memory which represents key information. """Store a single source code snippet in the database which represents key information.
Automatically adds the filepath of the snippet to related files. Automatically adds the filepath of the snippet to related files.
This is for **existing**, or **just-written** files, not for things to be created in the future. This is for **existing**, or **just-written** files, not for things to be created in the future.
@ -217,14 +215,8 @@ def emit_key_snippet(snippet_info: SnippetInfo) -> str:
description=snippet_info["description"], description=snippet_info["description"],
) )
# For backward compatibility, also store in global memory # Get the snippet ID from the database record
if "key_snippets" not in _global_memory: snippet_id = key_snippet.id
_global_memory["key_snippets"] = {}
# Use id_counter for compatibility with tests
snippet_id = _global_memory["key_snippet_id_counter"]
_global_memory["key_snippet_id_counter"] += 1
_global_memory["key_snippets"][snippet_id] = snippet_info
# Format display text as markdown # Format display text as markdown
display_text = [ display_text = [
@ -255,7 +247,7 @@ def emit_key_snippet(snippet_info: SnippetInfo) -> str:
@tool("delete_key_snippets") @tool("delete_key_snippets")
def delete_key_snippets(snippet_ids: List[int]) -> str: def delete_key_snippets(snippet_ids: List[int]) -> str:
"""Delete multiple key snippets from global memory by their IDs. """Delete multiple key snippets from the database by their IDs.
Silently skips any IDs that don't exist. Silently skips any IDs that don't exist.
Args: Args:
@ -263,27 +255,20 @@ def delete_key_snippets(snippet_ids: List[int]) -> str:
""" """
results = [] results = []
for snippet_id in snippet_ids: for snippet_id in snippet_ids:
# Try to delete from database first # Get the snippet first to capture filepath for the message
success = key_snippet_repository.delete(snippet_id) snippet = key_snippet_repository.get(snippet_id)
if snippet:
# For backward compatibility, also delete from global memory filepath = snippet.filepath
if snippet_id in _global_memory["key_snippets"]: # Delete from database
deleted_snippet = _global_memory["key_snippets"].pop(snippet_id) success = key_snippet_repository.delete(snippet_id)
filepath = deleted_snippet['filepath']
else:
# If not in memory but successful database delete, use generic message
if success: if success:
filepath = "database" success_msg = f"Successfully deleted snippet #{snippet_id} from {filepath}"
else: console.print(
continue # Skip if not found in either place Panel(
Markdown(success_msg), title="Snippet Deleted", border_style="green"
success_msg = f"Successfully deleted snippet #{snippet_id} from {filepath}" )
console.print( )
Panel( results.append(success_msg)
Markdown(success_msg), title="Snippet Deleted", border_style="green"
)
)
results.append(success_msg)
log_work_event(f"Deleted snippets {snippet_ids}.") log_work_event(f"Deleted snippets {snippet_ids}.")
return "Snippets deleted." return "Snippets deleted."
@ -606,8 +591,8 @@ def get_memory_value(key: str) -> str:
""" """
Get a value from global memory. Get a value from global memory.
Note: Key facts and key snippets are now handled by their respective repositories Note: Key facts and key snippets are handled by their respective repositories
and formatter modules, but this function maintains backward compatibility. and formatter modules.
Different memory types return different formats: Different memory types return different formats:
- key_snippets: Returns formatted snippets with file path, line number and content - key_snippets: Returns formatted snippets with file path, line number and content
@ -623,61 +608,12 @@ def get_memory_value(key: str) -> str:
""" """
if key == "key_snippets": if key == "key_snippets":
try: try:
# Try to get snippets from repository first # Get snippets from repository
snippets_dict = key_snippet_repository.get_snippets_dict() snippets_dict = key_snippet_repository.get_snippets_dict()
if snippets_dict: return key_snippets_formatter.format_key_snippets_dict(snippets_dict)
return key_snippets_formatter.format_key_snippets_dict(snippets_dict)
# Fallback to global memory for backward compatibility
values = _global_memory.get(key, {})
if not values:
return ""
# Format each snippet with file info and content using markdown
snippets = []
for k, v in sorted(values.items()):
snippet_text = [
f"## 📝 Code Snippet #{k}",
"", # Empty line for better markdown spacing
"**Source Location**:",
f"- File: `{v['filepath']}`",
f"- Line: `{v['line_number']}`",
"", # Empty line before code block
"**Code**:",
"```python",
v["snippet"].rstrip(), # Remove trailing whitespace
"```",
]
if v["description"]:
# Add empty line and description
snippet_text.extend(["", "**Description**:", v["description"]])
snippets.append("\n".join(snippet_text))
return "\n\n".join(snippets)
except Exception as e: except Exception as e:
logger.error(f"Error retrieving key snippets: {str(e)}") logger.error(f"Error retrieving key snippets: {str(e)}")
# If there's an error with the repository, fall back to global memory return ""
values = _global_memory.get(key, {})
if not values:
return ""
# (Same formatting code as above)
snippets = []
for k, v in sorted(values.items()):
snippet_text = [
f"## 📝 Code Snippet #{k}",
"", # Empty line for better markdown spacing
"**Source Location**:",
f"- File: `{v['filepath']}`",
f"- Line: `{v['line_number']}`",
"", # Empty line before code block
"**Code**:",
"```python",
v["snippet"].rstrip(), # Remove trailing whitespace
"```",
]
if v["description"]:
# Add empty line and description
snippet_text.extend(["", "**Description**:", v["description"]])
snippets.append("\n".join(snippet_text))
return "\n\n".join(snippets)
if key == "work_log": if key == "work_log":
values = _global_memory.get(key, []) values = _global_memory.get(key, [])

View File

@ -17,6 +17,7 @@ from ra_aid.tools.memory import (
get_related_files, get_related_files,
get_work_log, get_work_log,
key_fact_repository, key_fact_repository,
key_snippet_repository,
log_work_event, log_work_event,
reset_work_log, reset_work_log,
swap_task_order, swap_task_order,
@ -30,8 +31,6 @@ def reset_memory():
"""Reset global memory before each test""" """Reset global memory before each test"""
_global_memory["key_facts"] = {} _global_memory["key_facts"] = {}
_global_memory["key_fact_id_counter"] = 0 _global_memory["key_fact_id_counter"] = 0
_global_memory["key_snippets"] = {}
_global_memory["key_snippet_id_counter"] = 0
_global_memory["research_notes"] = [] _global_memory["research_notes"] = []
_global_memory["plans"] = [] _global_memory["plans"] = []
_global_memory["tasks"] = {} _global_memory["tasks"] = {}
@ -43,8 +42,6 @@ def reset_memory():
# Clean up after test # Clean up after test
_global_memory["key_facts"] = {} _global_memory["key_facts"] = {}
_global_memory["key_fact_id_counter"] = 0 _global_memory["key_fact_id_counter"] = 0
_global_memory["key_snippets"] = {}
_global_memory["key_snippet_id_counter"] = 0
_global_memory["research_notes"] = [] _global_memory["research_notes"] = []
_global_memory["plans"] = [] _global_memory["plans"] = []
_global_memory["tasks"] = {} _global_memory["tasks"] = {}
@ -113,6 +110,66 @@ def mock_repository():
yield mock_repo yield mock_repo
@pytest.fixture(autouse=True)
def mock_key_snippet_repository():
"""Mock the KeySnippetRepository to avoid database operations during tests"""
with patch('ra_aid.tools.memory.key_snippet_repository') as mock_repo:
# Setup the mock repository to behave like the original, but using memory
snippets = {} # Local in-memory storage
snippet_id_counter = 0
# Mock KeySnippet objects
class MockKeySnippet:
def __init__(self, id, filepath, line_number, snippet, description=None):
self.id = id
self.filepath = filepath
self.line_number = line_number
self.snippet = snippet
self.description = description
# Mock create method
def mock_create(filepath, line_number, snippet, description=None):
nonlocal snippet_id_counter
key_snippet = MockKeySnippet(snippet_id_counter, filepath, line_number, snippet, description)
snippets[snippet_id_counter] = key_snippet
snippet_id_counter += 1
return key_snippet
mock_repo.create.side_effect = mock_create
# Mock get method
def mock_get(snippet_id):
return snippets.get(snippet_id)
mock_repo.get.side_effect = mock_get
# Mock delete method
def mock_delete(snippet_id):
if snippet_id in snippets:
del snippets[snippet_id]
return True
return False
mock_repo.delete.side_effect = mock_delete
# Mock get_snippets_dict method
def mock_get_snippets_dict():
return {
snippet_id: {
"filepath": snippet.filepath,
"line_number": snippet.line_number,
"snippet": snippet.snippet,
"description": snippet.description
}
for snippet_id, snippet in snippets.items()
}
mock_repo.get_snippets_dict.side_effect = mock_get_snippets_dict
# Mock get_all method
def mock_get_all():
return list(snippets.values())
mock_repo.get_all.side_effect = mock_get_all
yield mock_repo
def test_emit_key_facts_single_fact(reset_memory, mock_repository): def test_emit_key_facts_single_fact(reset_memory, mock_repository):
"""Test emitting a single key fact using emit_key_facts""" """Test emitting a single key fact using emit_key_facts"""
# Test with single fact # Test with single fact
@ -238,7 +295,7 @@ def test_emit_key_facts_triggers_cleaner(reset_memory, mock_repository):
mock_repository.get_all.assert_called_once() mock_repository.get_all.assert_called_once()
def test_emit_key_snippet(reset_memory): def test_emit_key_snippet(reset_memory, mock_key_snippet_repository):
"""Test emitting a single code snippet""" """Test emitting a single code snippet"""
# Test snippet with description # Test snippet with description
snippet = { snippet = {
@ -254,11 +311,13 @@ def test_emit_key_snippet(reset_memory):
# Verify return message # Verify return message
assert result == "Snippet #0 stored." assert result == "Snippet #0 stored."
# Verify snippet stored correctly # Verify create was called correctly
assert _global_memory["key_snippets"][0] == snippet mock_key_snippet_repository.create.assert_called_with(
filepath="test.py",
# Verify counter incremented correctly line_number=10,
assert _global_memory["key_snippet_id_counter"] == 1 snippet="def test():\n pass",
description="Test function"
)
# Test snippet without description # Test snippet without description
snippet2 = { snippet2 = {
@ -274,16 +333,18 @@ def test_emit_key_snippet(reset_memory):
# Verify return message # Verify return message
assert result == "Snippet #1 stored." assert result == "Snippet #1 stored."
# Verify snippet stored correctly # Verify create was called correctly
assert _global_memory["key_snippets"][1] == snippet2 mock_key_snippet_repository.create.assert_called_with(
filepath="main.py",
# Verify counter incremented correctly line_number=20,
assert _global_memory["key_snippet_id_counter"] == 2 snippet="print('hello')",
description=None
)
def test_delete_key_snippets(reset_memory): def test_delete_key_snippets(reset_memory, mock_key_snippet_repository):
"""Test deleting multiple code snippets""" """Test deleting multiple code snippets"""
# Add test snippets # Mock snippets
snippets = [ snippets = [
{ {
"filepath": "test1.py", "filepath": "test1.py",
@ -308,20 +369,28 @@ def test_delete_key_snippets(reset_memory):
for snippet in snippets: for snippet in snippets:
emit_key_snippet.invoke({"snippet_info": snippet}) emit_key_snippet.invoke({"snippet_info": snippet})
# Reset mock to clear call history
mock_key_snippet_repository.reset_mock()
# Test deleting mix of valid and invalid IDs # Test deleting mix of valid and invalid IDs
result = delete_key_snippets.invoke({"snippet_ids": [0, 1, 999]}) result = delete_key_snippets.invoke({"snippet_ids": [0, 1, 999]})
# Verify success message # Verify success message
assert result == "Snippets deleted." assert result == "Snippets deleted."
# Verify correct snippets removed # Verify repository delete was called with correct IDs
assert 0 not in _global_memory["key_snippets"] mock_key_snippet_repository.get.assert_any_call(0)
assert 1 not in _global_memory["key_snippets"] mock_key_snippet_repository.get.assert_any_call(1)
assert 2 in _global_memory["key_snippets"] mock_key_snippet_repository.get.assert_any_call(999)
assert _global_memory["key_snippets"][2]["filepath"] == "test3.py"
mock_key_snippet_repository.delete.assert_any_call(0)
mock_key_snippet_repository.delete.assert_any_call(1)
# Make sure delete wasn't called for ID 999
assert mock_key_snippet_repository.delete.call_count == 2
def test_delete_key_snippets_empty(reset_memory): def test_delete_key_snippets_empty(reset_memory, mock_key_snippet_repository):
"""Test deleting snippets with empty ID list""" """Test deleting snippets with empty ID list"""
# Add a test snippet # Add a test snippet
snippet = { snippet = {
@ -331,13 +400,16 @@ def test_delete_key_snippets_empty(reset_memory):
"description": None, "description": None,
} }
emit_key_snippet.invoke({"snippet_info": snippet}) emit_key_snippet.invoke({"snippet_info": snippet})
# Reset mock to clear call history
mock_key_snippet_repository.reset_mock()
# Test with empty list # Test with empty list
result = delete_key_snippets.invoke({"snippet_ids": []}) result = delete_key_snippets.invoke({"snippet_ids": []})
assert result == "Snippets deleted." assert result == "Snippets deleted."
# Verify snippet still exists # Verify no call to delete method
assert 0 in _global_memory["key_snippets"] mock_key_snippet_repository.delete.assert_not_called()
def test_emit_related_files_basic(reset_memory, tmp_path): def test_emit_related_files_basic(reset_memory, tmp_path):
@ -541,7 +613,7 @@ def test_emit_related_files_path_normalization(reset_memory, tmp_path):
os.chdir(original_dir) os.chdir(original_dir)
def test_key_snippets_integration(reset_memory, tmp_path): def test_key_snippets_integration(reset_memory, tmp_path, mock_key_snippet_repository):
"""Integration test for key snippets functionality""" """Integration test for key snippets functionality"""
# Create test files # Create test files
file1 = tmp_path / "file1.py" file1 = tmp_path / "file1.py"
@ -577,7 +649,7 @@ def test_key_snippets_integration(reset_memory, tmp_path):
for i, snippet in enumerate(snippets): for i, snippet in enumerate(snippets):
result = emit_key_snippet.invoke({"snippet_info": snippet}) result = emit_key_snippet.invoke({"snippet_info": snippet})
assert result == f"Snippet #{i} stored." assert result == f"Snippet #{i} stored."
assert _global_memory["key_snippet_id_counter"] == 3
# Verify related files were tracked with IDs # Verify related files were tracked with IDs
assert len(_global_memory["related_files"]) == 3 assert len(_global_memory["related_files"]) == 3
# Check files are stored with proper IDs # Check files are stored with proper IDs
@ -586,25 +658,25 @@ def test_key_snippets_integration(reset_memory, tmp_path):
assert str(file2) in file_values assert str(file2) in file_values
assert str(file3) in file_values assert str(file3) in file_values
# Verify all snippets were stored correctly # Verify repository create was called for each snippet
assert len(_global_memory["key_snippets"]) == 3 assert mock_key_snippet_repository.create.call_count == 3
assert _global_memory["key_snippets"][0] == snippets[0]
assert _global_memory["key_snippets"][1] == snippets[1] # Reset mock to clear call history
assert _global_memory["key_snippets"][2] == snippets[2] mock_key_snippet_repository.reset_mock()
# Delete some but not all snippets (0 and 2) # Delete some but not all snippets (0 and 2)
result = delete_key_snippets.invoke({"snippet_ids": [0, 2]}) result = delete_key_snippets.invoke({"snippet_ids": [0, 2]})
assert result == "Snippets deleted." assert result == "Snippets deleted."
# Verify remaining snippet is intact # Verify delete was called for the correct IDs
assert len(_global_memory["key_snippets"]) == 1 mock_key_snippet_repository.delete.assert_any_call(0)
assert 1 in _global_memory["key_snippets"] mock_key_snippet_repository.delete.assert_any_call(2)
assert _global_memory["key_snippets"][1] == snippets[1] assert mock_key_snippet_repository.delete.call_count == 2
# Reset mock again
mock_key_snippet_repository.reset_mock()
# Counter should remain unchanged after deletions # Add new snippet
assert _global_memory["key_snippet_id_counter"] == 3
# Add new snippet to verify counter continues correctly
file4 = tmp_path / "file4.py" file4 = tmp_path / "file4.py"
file4.write_text("def func4():\n return False") file4.write_text("def func4():\n return False")
new_snippet = { new_snippet = {
@ -615,21 +687,31 @@ def test_key_snippets_integration(reset_memory, tmp_path):
} }
result = emit_key_snippet.invoke({"snippet_info": new_snippet}) result = emit_key_snippet.invoke({"snippet_info": new_snippet})
assert result == "Snippet #3 stored." assert result == "Snippet #3 stored."
assert _global_memory["key_snippet_id_counter"] == 4
# Verify create was called with correct params
mock_key_snippet_repository.create.assert_called_with(
filepath=str(file4),
line_number=40,
snippet="def func4():\n return False",
description="Fourth function"
)
# Verify new file was added to related files # Verify new file was added to related files
file_values = _global_memory["related_files"].values() file_values = _global_memory["related_files"].values()
assert str(file4) in file_values assert str(file4) in file_values
assert len(_global_memory["related_files"]) == 4 assert len(_global_memory["related_files"]) == 4
# Reset mock again
mock_key_snippet_repository.reset_mock()
# Delete remaining snippets # Delete remaining snippets
result = delete_key_snippets.invoke({"snippet_ids": [1, 3]}) result = delete_key_snippets.invoke({"snippet_ids": [1, 3]})
assert result == "Snippets deleted." assert result == "Snippets deleted."
# Verify all snippets are gone # Verify delete was called for the correct IDs
assert len(_global_memory["key_snippets"]) == 0 mock_key_snippet_repository.delete.assert_any_call(1)
mock_key_snippet_repository.delete.assert_any_call(3)
# Counter should still maintain its value assert mock_key_snippet_repository.delete.call_count == 2
assert _global_memory["key_snippet_id_counter"] == 4
def test_emit_task_with_id(reset_memory): def test_emit_task_with_id(reset_memory):