only access snippets via repository

This commit is contained in:
AI Christianson 2025-03-02 19:33:03 -05:00
parent cb3504016f
commit 772ce3e049
2 changed files with 150 additions and 132 deletions

View File

@ -53,8 +53,6 @@ _global_memory: Dict[str, Any] = {
"task_id_counter": 1, # Counter for generating unique task IDs
"key_facts": {}, # Dict[int, str] - ID to fact mapping (deprecated, using DB now)
"key_fact_id_counter": 1, # Counter for generating unique fact IDs (deprecated, using DB now)
"key_snippets": {}, # Dict[int, SnippetInfo] - ID to snippet mapping
"key_snippet_id_counter": 1, # Counter for generating unique snippet IDs
"implementation_requested": False,
"related_files": {}, # Dict[int, str] - ID to filepath mapping
"related_file_id_counter": 1, # Counter for generating unique file IDs
@ -190,7 +188,7 @@ def request_implementation() -> str:
@tool("emit_key_snippet")
def emit_key_snippet(snippet_info: SnippetInfo) -> str:
"""Store a single source code snippet in global memory which represents key information.
"""Store a single source code snippet in the database which represents key information.
Automatically adds the filepath of the snippet to related files.
This is for **existing**, or **just-written** files, not for things to be created in the future.
@ -217,14 +215,8 @@ def emit_key_snippet(snippet_info: SnippetInfo) -> str:
description=snippet_info["description"],
)
# For backward compatibility, also store in global memory
if "key_snippets" not in _global_memory:
_global_memory["key_snippets"] = {}
# Use id_counter for compatibility with tests
snippet_id = _global_memory["key_snippet_id_counter"]
_global_memory["key_snippet_id_counter"] += 1
_global_memory["key_snippets"][snippet_id] = snippet_info
# Get the snippet ID from the database record
snippet_id = key_snippet.id
# Format display text as markdown
display_text = [
@ -255,7 +247,7 @@ def emit_key_snippet(snippet_info: SnippetInfo) -> str:
@tool("delete_key_snippets")
def delete_key_snippets(snippet_ids: List[int]) -> str:
"""Delete multiple key snippets from global memory by their IDs.
"""Delete multiple key snippets from the database by their IDs.
Silently skips any IDs that don't exist.
Args:
@ -263,20 +255,13 @@ def delete_key_snippets(snippet_ids: List[int]) -> str:
"""
results = []
for snippet_id in snippet_ids:
# Try to delete from database first
# Get the snippet first to capture filepath for the message
snippet = key_snippet_repository.get(snippet_id)
if snippet:
filepath = snippet.filepath
# Delete from database
success = key_snippet_repository.delete(snippet_id)
# For backward compatibility, also delete from global memory
if snippet_id in _global_memory["key_snippets"]:
deleted_snippet = _global_memory["key_snippets"].pop(snippet_id)
filepath = deleted_snippet['filepath']
else:
# If not in memory but successful database delete, use generic message
if success:
filepath = "database"
else:
continue # Skip if not found in either place
success_msg = f"Successfully deleted snippet #{snippet_id} from {filepath}"
console.print(
Panel(
@ -606,8 +591,8 @@ def get_memory_value(key: str) -> str:
"""
Get a value from global memory.
Note: Key facts and key snippets are now handled by their respective repositories
and formatter modules, but this function maintains backward compatibility.
Note: Key facts and key snippets are handled by their respective repositories
and formatter modules.
Different memory types return different formats:
- key_snippets: Returns formatted snippets with file path, line number and content
@ -623,61 +608,12 @@ def get_memory_value(key: str) -> str:
"""
if key == "key_snippets":
try:
# Try to get snippets from repository first
# Get snippets from repository
snippets_dict = key_snippet_repository.get_snippets_dict()
if snippets_dict:
return key_snippets_formatter.format_key_snippets_dict(snippets_dict)
# Fallback to global memory for backward compatibility
values = _global_memory.get(key, {})
if not values:
return ""
# Format each snippet with file info and content using markdown
snippets = []
for k, v in sorted(values.items()):
snippet_text = [
f"## 📝 Code Snippet #{k}",
"", # Empty line for better markdown spacing
"**Source Location**:",
f"- File: `{v['filepath']}`",
f"- Line: `{v['line_number']}`",
"", # Empty line before code block
"**Code**:",
"```python",
v["snippet"].rstrip(), # Remove trailing whitespace
"```",
]
if v["description"]:
# Add empty line and description
snippet_text.extend(["", "**Description**:", v["description"]])
snippets.append("\n".join(snippet_text))
return "\n\n".join(snippets)
except Exception as e:
logger.error(f"Error retrieving key snippets: {str(e)}")
# If there's an error with the repository, fall back to global memory
values = _global_memory.get(key, {})
if not values:
return ""
# (Same formatting code as above)
snippets = []
for k, v in sorted(values.items()):
snippet_text = [
f"## 📝 Code Snippet #{k}",
"", # Empty line for better markdown spacing
"**Source Location**:",
f"- File: `{v['filepath']}`",
f"- Line: `{v['line_number']}`",
"", # Empty line before code block
"**Code**:",
"```python",
v["snippet"].rstrip(), # Remove trailing whitespace
"```",
]
if v["description"]:
# Add empty line and description
snippet_text.extend(["", "**Description**:", v["description"]])
snippets.append("\n".join(snippet_text))
return "\n\n".join(snippets)
if key == "work_log":
values = _global_memory.get(key, [])

View File

@ -17,6 +17,7 @@ from ra_aid.tools.memory import (
get_related_files,
get_work_log,
key_fact_repository,
key_snippet_repository,
log_work_event,
reset_work_log,
swap_task_order,
@ -30,8 +31,6 @@ def reset_memory():
"""Reset global memory before each test"""
_global_memory["key_facts"] = {}
_global_memory["key_fact_id_counter"] = 0
_global_memory["key_snippets"] = {}
_global_memory["key_snippet_id_counter"] = 0
_global_memory["research_notes"] = []
_global_memory["plans"] = []
_global_memory["tasks"] = {}
@ -43,8 +42,6 @@ def reset_memory():
# Clean up after test
_global_memory["key_facts"] = {}
_global_memory["key_fact_id_counter"] = 0
_global_memory["key_snippets"] = {}
_global_memory["key_snippet_id_counter"] = 0
_global_memory["research_notes"] = []
_global_memory["plans"] = []
_global_memory["tasks"] = {}
@ -113,6 +110,66 @@ def mock_repository():
yield mock_repo
@pytest.fixture(autouse=True)
def mock_key_snippet_repository():
"""Mock the KeySnippetRepository to avoid database operations during tests"""
with patch('ra_aid.tools.memory.key_snippet_repository') as mock_repo:
# Setup the mock repository to behave like the original, but using memory
snippets = {} # Local in-memory storage
snippet_id_counter = 0
# Mock KeySnippet objects
class MockKeySnippet:
def __init__(self, id, filepath, line_number, snippet, description=None):
self.id = id
self.filepath = filepath
self.line_number = line_number
self.snippet = snippet
self.description = description
# Mock create method
def mock_create(filepath, line_number, snippet, description=None):
nonlocal snippet_id_counter
key_snippet = MockKeySnippet(snippet_id_counter, filepath, line_number, snippet, description)
snippets[snippet_id_counter] = key_snippet
snippet_id_counter += 1
return key_snippet
mock_repo.create.side_effect = mock_create
# Mock get method
def mock_get(snippet_id):
return snippets.get(snippet_id)
mock_repo.get.side_effect = mock_get
# Mock delete method
def mock_delete(snippet_id):
if snippet_id in snippets:
del snippets[snippet_id]
return True
return False
mock_repo.delete.side_effect = mock_delete
# Mock get_snippets_dict method
def mock_get_snippets_dict():
return {
snippet_id: {
"filepath": snippet.filepath,
"line_number": snippet.line_number,
"snippet": snippet.snippet,
"description": snippet.description
}
for snippet_id, snippet in snippets.items()
}
mock_repo.get_snippets_dict.side_effect = mock_get_snippets_dict
# Mock get_all method
def mock_get_all():
return list(snippets.values())
mock_repo.get_all.side_effect = mock_get_all
yield mock_repo
def test_emit_key_facts_single_fact(reset_memory, mock_repository):
"""Test emitting a single key fact using emit_key_facts"""
# Test with single fact
@ -238,7 +295,7 @@ def test_emit_key_facts_triggers_cleaner(reset_memory, mock_repository):
mock_repository.get_all.assert_called_once()
def test_emit_key_snippet(reset_memory):
def test_emit_key_snippet(reset_memory, mock_key_snippet_repository):
"""Test emitting a single code snippet"""
# Test snippet with description
snippet = {
@ -254,11 +311,13 @@ def test_emit_key_snippet(reset_memory):
# Verify return message
assert result == "Snippet #0 stored."
# Verify snippet stored correctly
assert _global_memory["key_snippets"][0] == snippet
# Verify counter incremented correctly
assert _global_memory["key_snippet_id_counter"] == 1
# Verify create was called correctly
mock_key_snippet_repository.create.assert_called_with(
filepath="test.py",
line_number=10,
snippet="def test():\n pass",
description="Test function"
)
# Test snippet without description
snippet2 = {
@ -274,16 +333,18 @@ def test_emit_key_snippet(reset_memory):
# Verify return message
assert result == "Snippet #1 stored."
# Verify snippet stored correctly
assert _global_memory["key_snippets"][1] == snippet2
# Verify counter incremented correctly
assert _global_memory["key_snippet_id_counter"] == 2
# Verify create was called correctly
mock_key_snippet_repository.create.assert_called_with(
filepath="main.py",
line_number=20,
snippet="print('hello')",
description=None
)
def test_delete_key_snippets(reset_memory):
def test_delete_key_snippets(reset_memory, mock_key_snippet_repository):
"""Test deleting multiple code snippets"""
# Add test snippets
# Mock snippets
snippets = [
{
"filepath": "test1.py",
@ -308,20 +369,28 @@ def test_delete_key_snippets(reset_memory):
for snippet in snippets:
emit_key_snippet.invoke({"snippet_info": snippet})
# Reset mock to clear call history
mock_key_snippet_repository.reset_mock()
# Test deleting mix of valid and invalid IDs
result = delete_key_snippets.invoke({"snippet_ids": [0, 1, 999]})
# Verify success message
assert result == "Snippets deleted."
# Verify correct snippets removed
assert 0 not in _global_memory["key_snippets"]
assert 1 not in _global_memory["key_snippets"]
assert 2 in _global_memory["key_snippets"]
assert _global_memory["key_snippets"][2]["filepath"] == "test3.py"
# Verify repository delete was called with correct IDs
mock_key_snippet_repository.get.assert_any_call(0)
mock_key_snippet_repository.get.assert_any_call(1)
mock_key_snippet_repository.get.assert_any_call(999)
mock_key_snippet_repository.delete.assert_any_call(0)
mock_key_snippet_repository.delete.assert_any_call(1)
# Make sure delete wasn't called for ID 999
assert mock_key_snippet_repository.delete.call_count == 2
def test_delete_key_snippets_empty(reset_memory):
def test_delete_key_snippets_empty(reset_memory, mock_key_snippet_repository):
"""Test deleting snippets with empty ID list"""
# Add a test snippet
snippet = {
@ -332,12 +401,15 @@ def test_delete_key_snippets_empty(reset_memory):
}
emit_key_snippet.invoke({"snippet_info": snippet})
# Reset mock to clear call history
mock_key_snippet_repository.reset_mock()
# Test with empty list
result = delete_key_snippets.invoke({"snippet_ids": []})
assert result == "Snippets deleted."
# Verify snippet still exists
assert 0 in _global_memory["key_snippets"]
# Verify no call to delete method
mock_key_snippet_repository.delete.assert_not_called()
def test_emit_related_files_basic(reset_memory, tmp_path):
@ -541,7 +613,7 @@ def test_emit_related_files_path_normalization(reset_memory, tmp_path):
os.chdir(original_dir)
def test_key_snippets_integration(reset_memory, tmp_path):
def test_key_snippets_integration(reset_memory, tmp_path, mock_key_snippet_repository):
"""Integration test for key snippets functionality"""
# Create test files
file1 = tmp_path / "file1.py"
@ -577,7 +649,7 @@ def test_key_snippets_integration(reset_memory, tmp_path):
for i, snippet in enumerate(snippets):
result = emit_key_snippet.invoke({"snippet_info": snippet})
assert result == f"Snippet #{i} stored."
assert _global_memory["key_snippet_id_counter"] == 3
# Verify related files were tracked with IDs
assert len(_global_memory["related_files"]) == 3
# Check files are stored with proper IDs
@ -586,25 +658,25 @@ def test_key_snippets_integration(reset_memory, tmp_path):
assert str(file2) in file_values
assert str(file3) in file_values
# Verify all snippets were stored correctly
assert len(_global_memory["key_snippets"]) == 3
assert _global_memory["key_snippets"][0] == snippets[0]
assert _global_memory["key_snippets"][1] == snippets[1]
assert _global_memory["key_snippets"][2] == snippets[2]
# Verify repository create was called for each snippet
assert mock_key_snippet_repository.create.call_count == 3
# Reset mock to clear call history
mock_key_snippet_repository.reset_mock()
# Delete some but not all snippets (0 and 2)
result = delete_key_snippets.invoke({"snippet_ids": [0, 2]})
assert result == "Snippets deleted."
# Verify remaining snippet is intact
assert len(_global_memory["key_snippets"]) == 1
assert 1 in _global_memory["key_snippets"]
assert _global_memory["key_snippets"][1] == snippets[1]
# Verify delete was called for the correct IDs
mock_key_snippet_repository.delete.assert_any_call(0)
mock_key_snippet_repository.delete.assert_any_call(2)
assert mock_key_snippet_repository.delete.call_count == 2
# Counter should remain unchanged after deletions
assert _global_memory["key_snippet_id_counter"] == 3
# Reset mock again
mock_key_snippet_repository.reset_mock()
# Add new snippet to verify counter continues correctly
# Add new snippet
file4 = tmp_path / "file4.py"
file4.write_text("def func4():\n return False")
new_snippet = {
@ -615,21 +687,31 @@ def test_key_snippets_integration(reset_memory, tmp_path):
}
result = emit_key_snippet.invoke({"snippet_info": new_snippet})
assert result == "Snippet #3 stored."
assert _global_memory["key_snippet_id_counter"] == 4
# Verify create was called with correct params
mock_key_snippet_repository.create.assert_called_with(
filepath=str(file4),
line_number=40,
snippet="def func4():\n return False",
description="Fourth function"
)
# Verify new file was added to related files
file_values = _global_memory["related_files"].values()
assert str(file4) in file_values
assert len(_global_memory["related_files"]) == 4
# Reset mock again
mock_key_snippet_repository.reset_mock()
# Delete remaining snippets
result = delete_key_snippets.invoke({"snippet_ids": [1, 3]})
assert result == "Snippets deleted."
# Verify all snippets are gone
assert len(_global_memory["key_snippets"]) == 0
# Counter should still maintain its value
assert _global_memory["key_snippet_id_counter"] == 4
# Verify delete was called for the correct IDs
mock_key_snippet_repository.delete.assert_any_call(1)
mock_key_snippet_repository.delete.assert_any_call(3)
assert mock_key_snippet_repository.delete.call_count == 2
def test_emit_task_with_id(reset_memory):