only access snippets via repository
This commit is contained in:
parent
cb3504016f
commit
772ce3e049
|
|
@ -53,8 +53,6 @@ _global_memory: Dict[str, Any] = {
|
|||
"task_id_counter": 1, # Counter for generating unique task IDs
|
||||
"key_facts": {}, # Dict[int, str] - ID to fact mapping (deprecated, using DB now)
|
||||
"key_fact_id_counter": 1, # Counter for generating unique fact IDs (deprecated, using DB now)
|
||||
"key_snippets": {}, # Dict[int, SnippetInfo] - ID to snippet mapping
|
||||
"key_snippet_id_counter": 1, # Counter for generating unique snippet IDs
|
||||
"implementation_requested": False,
|
||||
"related_files": {}, # Dict[int, str] - ID to filepath mapping
|
||||
"related_file_id_counter": 1, # Counter for generating unique file IDs
|
||||
|
|
@ -190,7 +188,7 @@ def request_implementation() -> str:
|
|||
|
||||
@tool("emit_key_snippet")
|
||||
def emit_key_snippet(snippet_info: SnippetInfo) -> str:
|
||||
"""Store a single source code snippet in global memory which represents key information.
|
||||
"""Store a single source code snippet in the database which represents key information.
|
||||
Automatically adds the filepath of the snippet to related files.
|
||||
|
||||
This is for **existing**, or **just-written** files, not for things to be created in the future.
|
||||
|
|
@ -217,14 +215,8 @@ def emit_key_snippet(snippet_info: SnippetInfo) -> str:
|
|||
description=snippet_info["description"],
|
||||
)
|
||||
|
||||
# For backward compatibility, also store in global memory
|
||||
if "key_snippets" not in _global_memory:
|
||||
_global_memory["key_snippets"] = {}
|
||||
|
||||
# Use id_counter for compatibility with tests
|
||||
snippet_id = _global_memory["key_snippet_id_counter"]
|
||||
_global_memory["key_snippet_id_counter"] += 1
|
||||
_global_memory["key_snippets"][snippet_id] = snippet_info
|
||||
# Get the snippet ID from the database record
|
||||
snippet_id = key_snippet.id
|
||||
|
||||
# Format display text as markdown
|
||||
display_text = [
|
||||
|
|
@ -255,7 +247,7 @@ def emit_key_snippet(snippet_info: SnippetInfo) -> str:
|
|||
|
||||
@tool("delete_key_snippets")
|
||||
def delete_key_snippets(snippet_ids: List[int]) -> str:
|
||||
"""Delete multiple key snippets from global memory by their IDs.
|
||||
"""Delete multiple key snippets from the database by their IDs.
|
||||
Silently skips any IDs that don't exist.
|
||||
|
||||
Args:
|
||||
|
|
@ -263,20 +255,13 @@ def delete_key_snippets(snippet_ids: List[int]) -> str:
|
|||
"""
|
||||
results = []
|
||||
for snippet_id in snippet_ids:
|
||||
# Try to delete from database first
|
||||
# Get the snippet first to capture filepath for the message
|
||||
snippet = key_snippet_repository.get(snippet_id)
|
||||
if snippet:
|
||||
filepath = snippet.filepath
|
||||
# Delete from database
|
||||
success = key_snippet_repository.delete(snippet_id)
|
||||
|
||||
# For backward compatibility, also delete from global memory
|
||||
if snippet_id in _global_memory["key_snippets"]:
|
||||
deleted_snippet = _global_memory["key_snippets"].pop(snippet_id)
|
||||
filepath = deleted_snippet['filepath']
|
||||
else:
|
||||
# If not in memory but successful database delete, use generic message
|
||||
if success:
|
||||
filepath = "database"
|
||||
else:
|
||||
continue # Skip if not found in either place
|
||||
|
||||
success_msg = f"Successfully deleted snippet #{snippet_id} from {filepath}"
|
||||
console.print(
|
||||
Panel(
|
||||
|
|
@ -606,8 +591,8 @@ def get_memory_value(key: str) -> str:
|
|||
"""
|
||||
Get a value from global memory.
|
||||
|
||||
Note: Key facts and key snippets are now handled by their respective repositories
|
||||
and formatter modules, but this function maintains backward compatibility.
|
||||
Note: Key facts and key snippets are handled by their respective repositories
|
||||
and formatter modules.
|
||||
|
||||
Different memory types return different formats:
|
||||
- key_snippets: Returns formatted snippets with file path, line number and content
|
||||
|
|
@ -623,61 +608,12 @@ def get_memory_value(key: str) -> str:
|
|||
"""
|
||||
if key == "key_snippets":
|
||||
try:
|
||||
# Try to get snippets from repository first
|
||||
# Get snippets from repository
|
||||
snippets_dict = key_snippet_repository.get_snippets_dict()
|
||||
if snippets_dict:
|
||||
return key_snippets_formatter.format_key_snippets_dict(snippets_dict)
|
||||
|
||||
# Fallback to global memory for backward compatibility
|
||||
values = _global_memory.get(key, {})
|
||||
if not values:
|
||||
return ""
|
||||
# Format each snippet with file info and content using markdown
|
||||
snippets = []
|
||||
for k, v in sorted(values.items()):
|
||||
snippet_text = [
|
||||
f"## 📝 Code Snippet #{k}",
|
||||
"", # Empty line for better markdown spacing
|
||||
"**Source Location**:",
|
||||
f"- File: `{v['filepath']}`",
|
||||
f"- Line: `{v['line_number']}`",
|
||||
"", # Empty line before code block
|
||||
"**Code**:",
|
||||
"```python",
|
||||
v["snippet"].rstrip(), # Remove trailing whitespace
|
||||
"```",
|
||||
]
|
||||
if v["description"]:
|
||||
# Add empty line and description
|
||||
snippet_text.extend(["", "**Description**:", v["description"]])
|
||||
snippets.append("\n".join(snippet_text))
|
||||
return "\n\n".join(snippets)
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving key snippets: {str(e)}")
|
||||
# If there's an error with the repository, fall back to global memory
|
||||
values = _global_memory.get(key, {})
|
||||
if not values:
|
||||
return ""
|
||||
# (Same formatting code as above)
|
||||
snippets = []
|
||||
for k, v in sorted(values.items()):
|
||||
snippet_text = [
|
||||
f"## 📝 Code Snippet #{k}",
|
||||
"", # Empty line for better markdown spacing
|
||||
"**Source Location**:",
|
||||
f"- File: `{v['filepath']}`",
|
||||
f"- Line: `{v['line_number']}`",
|
||||
"", # Empty line before code block
|
||||
"**Code**:",
|
||||
"```python",
|
||||
v["snippet"].rstrip(), # Remove trailing whitespace
|
||||
"```",
|
||||
]
|
||||
if v["description"]:
|
||||
# Add empty line and description
|
||||
snippet_text.extend(["", "**Description**:", v["description"]])
|
||||
snippets.append("\n".join(snippet_text))
|
||||
return "\n\n".join(snippets)
|
||||
|
||||
if key == "work_log":
|
||||
values = _global_memory.get(key, [])
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@ from ra_aid.tools.memory import (
|
|||
get_related_files,
|
||||
get_work_log,
|
||||
key_fact_repository,
|
||||
key_snippet_repository,
|
||||
log_work_event,
|
||||
reset_work_log,
|
||||
swap_task_order,
|
||||
|
|
@ -30,8 +31,6 @@ def reset_memory():
|
|||
"""Reset global memory before each test"""
|
||||
_global_memory["key_facts"] = {}
|
||||
_global_memory["key_fact_id_counter"] = 0
|
||||
_global_memory["key_snippets"] = {}
|
||||
_global_memory["key_snippet_id_counter"] = 0
|
||||
_global_memory["research_notes"] = []
|
||||
_global_memory["plans"] = []
|
||||
_global_memory["tasks"] = {}
|
||||
|
|
@ -43,8 +42,6 @@ def reset_memory():
|
|||
# Clean up after test
|
||||
_global_memory["key_facts"] = {}
|
||||
_global_memory["key_fact_id_counter"] = 0
|
||||
_global_memory["key_snippets"] = {}
|
||||
_global_memory["key_snippet_id_counter"] = 0
|
||||
_global_memory["research_notes"] = []
|
||||
_global_memory["plans"] = []
|
||||
_global_memory["tasks"] = {}
|
||||
|
|
@ -113,6 +110,66 @@ def mock_repository():
|
|||
yield mock_repo
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_key_snippet_repository():
|
||||
"""Mock the KeySnippetRepository to avoid database operations during tests"""
|
||||
with patch('ra_aid.tools.memory.key_snippet_repository') as mock_repo:
|
||||
# Setup the mock repository to behave like the original, but using memory
|
||||
snippets = {} # Local in-memory storage
|
||||
snippet_id_counter = 0
|
||||
|
||||
# Mock KeySnippet objects
|
||||
class MockKeySnippet:
|
||||
def __init__(self, id, filepath, line_number, snippet, description=None):
|
||||
self.id = id
|
||||
self.filepath = filepath
|
||||
self.line_number = line_number
|
||||
self.snippet = snippet
|
||||
self.description = description
|
||||
|
||||
# Mock create method
|
||||
def mock_create(filepath, line_number, snippet, description=None):
|
||||
nonlocal snippet_id_counter
|
||||
key_snippet = MockKeySnippet(snippet_id_counter, filepath, line_number, snippet, description)
|
||||
snippets[snippet_id_counter] = key_snippet
|
||||
snippet_id_counter += 1
|
||||
return key_snippet
|
||||
mock_repo.create.side_effect = mock_create
|
||||
|
||||
# Mock get method
|
||||
def mock_get(snippet_id):
|
||||
return snippets.get(snippet_id)
|
||||
mock_repo.get.side_effect = mock_get
|
||||
|
||||
# Mock delete method
|
||||
def mock_delete(snippet_id):
|
||||
if snippet_id in snippets:
|
||||
del snippets[snippet_id]
|
||||
return True
|
||||
return False
|
||||
mock_repo.delete.side_effect = mock_delete
|
||||
|
||||
# Mock get_snippets_dict method
|
||||
def mock_get_snippets_dict():
|
||||
return {
|
||||
snippet_id: {
|
||||
"filepath": snippet.filepath,
|
||||
"line_number": snippet.line_number,
|
||||
"snippet": snippet.snippet,
|
||||
"description": snippet.description
|
||||
}
|
||||
for snippet_id, snippet in snippets.items()
|
||||
}
|
||||
mock_repo.get_snippets_dict.side_effect = mock_get_snippets_dict
|
||||
|
||||
# Mock get_all method
|
||||
def mock_get_all():
|
||||
return list(snippets.values())
|
||||
mock_repo.get_all.side_effect = mock_get_all
|
||||
|
||||
yield mock_repo
|
||||
|
||||
|
||||
def test_emit_key_facts_single_fact(reset_memory, mock_repository):
|
||||
"""Test emitting a single key fact using emit_key_facts"""
|
||||
# Test with single fact
|
||||
|
|
@ -238,7 +295,7 @@ def test_emit_key_facts_triggers_cleaner(reset_memory, mock_repository):
|
|||
mock_repository.get_all.assert_called_once()
|
||||
|
||||
|
||||
def test_emit_key_snippet(reset_memory):
|
||||
def test_emit_key_snippet(reset_memory, mock_key_snippet_repository):
|
||||
"""Test emitting a single code snippet"""
|
||||
# Test snippet with description
|
||||
snippet = {
|
||||
|
|
@ -254,11 +311,13 @@ def test_emit_key_snippet(reset_memory):
|
|||
# Verify return message
|
||||
assert result == "Snippet #0 stored."
|
||||
|
||||
# Verify snippet stored correctly
|
||||
assert _global_memory["key_snippets"][0] == snippet
|
||||
|
||||
# Verify counter incremented correctly
|
||||
assert _global_memory["key_snippet_id_counter"] == 1
|
||||
# Verify create was called correctly
|
||||
mock_key_snippet_repository.create.assert_called_with(
|
||||
filepath="test.py",
|
||||
line_number=10,
|
||||
snippet="def test():\n pass",
|
||||
description="Test function"
|
||||
)
|
||||
|
||||
# Test snippet without description
|
||||
snippet2 = {
|
||||
|
|
@ -274,16 +333,18 @@ def test_emit_key_snippet(reset_memory):
|
|||
# Verify return message
|
||||
assert result == "Snippet #1 stored."
|
||||
|
||||
# Verify snippet stored correctly
|
||||
assert _global_memory["key_snippets"][1] == snippet2
|
||||
|
||||
# Verify counter incremented correctly
|
||||
assert _global_memory["key_snippet_id_counter"] == 2
|
||||
# Verify create was called correctly
|
||||
mock_key_snippet_repository.create.assert_called_with(
|
||||
filepath="main.py",
|
||||
line_number=20,
|
||||
snippet="print('hello')",
|
||||
description=None
|
||||
)
|
||||
|
||||
|
||||
def test_delete_key_snippets(reset_memory):
|
||||
def test_delete_key_snippets(reset_memory, mock_key_snippet_repository):
|
||||
"""Test deleting multiple code snippets"""
|
||||
# Add test snippets
|
||||
# Mock snippets
|
||||
snippets = [
|
||||
{
|
||||
"filepath": "test1.py",
|
||||
|
|
@ -308,20 +369,28 @@ def test_delete_key_snippets(reset_memory):
|
|||
for snippet in snippets:
|
||||
emit_key_snippet.invoke({"snippet_info": snippet})
|
||||
|
||||
# Reset mock to clear call history
|
||||
mock_key_snippet_repository.reset_mock()
|
||||
|
||||
# Test deleting mix of valid and invalid IDs
|
||||
result = delete_key_snippets.invoke({"snippet_ids": [0, 1, 999]})
|
||||
|
||||
# Verify success message
|
||||
assert result == "Snippets deleted."
|
||||
|
||||
# Verify correct snippets removed
|
||||
assert 0 not in _global_memory["key_snippets"]
|
||||
assert 1 not in _global_memory["key_snippets"]
|
||||
assert 2 in _global_memory["key_snippets"]
|
||||
assert _global_memory["key_snippets"][2]["filepath"] == "test3.py"
|
||||
# Verify repository delete was called with correct IDs
|
||||
mock_key_snippet_repository.get.assert_any_call(0)
|
||||
mock_key_snippet_repository.get.assert_any_call(1)
|
||||
mock_key_snippet_repository.get.assert_any_call(999)
|
||||
|
||||
mock_key_snippet_repository.delete.assert_any_call(0)
|
||||
mock_key_snippet_repository.delete.assert_any_call(1)
|
||||
|
||||
# Make sure delete wasn't called for ID 999
|
||||
assert mock_key_snippet_repository.delete.call_count == 2
|
||||
|
||||
|
||||
def test_delete_key_snippets_empty(reset_memory):
|
||||
def test_delete_key_snippets_empty(reset_memory, mock_key_snippet_repository):
|
||||
"""Test deleting snippets with empty ID list"""
|
||||
# Add a test snippet
|
||||
snippet = {
|
||||
|
|
@ -332,12 +401,15 @@ def test_delete_key_snippets_empty(reset_memory):
|
|||
}
|
||||
emit_key_snippet.invoke({"snippet_info": snippet})
|
||||
|
||||
# Reset mock to clear call history
|
||||
mock_key_snippet_repository.reset_mock()
|
||||
|
||||
# Test with empty list
|
||||
result = delete_key_snippets.invoke({"snippet_ids": []})
|
||||
assert result == "Snippets deleted."
|
||||
|
||||
# Verify snippet still exists
|
||||
assert 0 in _global_memory["key_snippets"]
|
||||
# Verify no call to delete method
|
||||
mock_key_snippet_repository.delete.assert_not_called()
|
||||
|
||||
|
||||
def test_emit_related_files_basic(reset_memory, tmp_path):
|
||||
|
|
@ -541,7 +613,7 @@ def test_emit_related_files_path_normalization(reset_memory, tmp_path):
|
|||
os.chdir(original_dir)
|
||||
|
||||
|
||||
def test_key_snippets_integration(reset_memory, tmp_path):
|
||||
def test_key_snippets_integration(reset_memory, tmp_path, mock_key_snippet_repository):
|
||||
"""Integration test for key snippets functionality"""
|
||||
# Create test files
|
||||
file1 = tmp_path / "file1.py"
|
||||
|
|
@ -577,7 +649,7 @@ def test_key_snippets_integration(reset_memory, tmp_path):
|
|||
for i, snippet in enumerate(snippets):
|
||||
result = emit_key_snippet.invoke({"snippet_info": snippet})
|
||||
assert result == f"Snippet #{i} stored."
|
||||
assert _global_memory["key_snippet_id_counter"] == 3
|
||||
|
||||
# Verify related files were tracked with IDs
|
||||
assert len(_global_memory["related_files"]) == 3
|
||||
# Check files are stored with proper IDs
|
||||
|
|
@ -586,25 +658,25 @@ def test_key_snippets_integration(reset_memory, tmp_path):
|
|||
assert str(file2) in file_values
|
||||
assert str(file3) in file_values
|
||||
|
||||
# Verify all snippets were stored correctly
|
||||
assert len(_global_memory["key_snippets"]) == 3
|
||||
assert _global_memory["key_snippets"][0] == snippets[0]
|
||||
assert _global_memory["key_snippets"][1] == snippets[1]
|
||||
assert _global_memory["key_snippets"][2] == snippets[2]
|
||||
# Verify repository create was called for each snippet
|
||||
assert mock_key_snippet_repository.create.call_count == 3
|
||||
|
||||
# Reset mock to clear call history
|
||||
mock_key_snippet_repository.reset_mock()
|
||||
|
||||
# Delete some but not all snippets (0 and 2)
|
||||
result = delete_key_snippets.invoke({"snippet_ids": [0, 2]})
|
||||
assert result == "Snippets deleted."
|
||||
|
||||
# Verify remaining snippet is intact
|
||||
assert len(_global_memory["key_snippets"]) == 1
|
||||
assert 1 in _global_memory["key_snippets"]
|
||||
assert _global_memory["key_snippets"][1] == snippets[1]
|
||||
# Verify delete was called for the correct IDs
|
||||
mock_key_snippet_repository.delete.assert_any_call(0)
|
||||
mock_key_snippet_repository.delete.assert_any_call(2)
|
||||
assert mock_key_snippet_repository.delete.call_count == 2
|
||||
|
||||
# Counter should remain unchanged after deletions
|
||||
assert _global_memory["key_snippet_id_counter"] == 3
|
||||
# Reset mock again
|
||||
mock_key_snippet_repository.reset_mock()
|
||||
|
||||
# Add new snippet to verify counter continues correctly
|
||||
# Add new snippet
|
||||
file4 = tmp_path / "file4.py"
|
||||
file4.write_text("def func4():\n return False")
|
||||
new_snippet = {
|
||||
|
|
@ -615,21 +687,31 @@ def test_key_snippets_integration(reset_memory, tmp_path):
|
|||
}
|
||||
result = emit_key_snippet.invoke({"snippet_info": new_snippet})
|
||||
assert result == "Snippet #3 stored."
|
||||
assert _global_memory["key_snippet_id_counter"] == 4
|
||||
|
||||
# Verify create was called with correct params
|
||||
mock_key_snippet_repository.create.assert_called_with(
|
||||
filepath=str(file4),
|
||||
line_number=40,
|
||||
snippet="def func4():\n return False",
|
||||
description="Fourth function"
|
||||
)
|
||||
|
||||
# Verify new file was added to related files
|
||||
file_values = _global_memory["related_files"].values()
|
||||
assert str(file4) in file_values
|
||||
assert len(_global_memory["related_files"]) == 4
|
||||
|
||||
# Reset mock again
|
||||
mock_key_snippet_repository.reset_mock()
|
||||
|
||||
# Delete remaining snippets
|
||||
result = delete_key_snippets.invoke({"snippet_ids": [1, 3]})
|
||||
assert result == "Snippets deleted."
|
||||
|
||||
# Verify all snippets are gone
|
||||
assert len(_global_memory["key_snippets"]) == 0
|
||||
|
||||
# Counter should still maintain its value
|
||||
assert _global_memory["key_snippet_id_counter"] == 4
|
||||
# Verify delete was called for the correct IDs
|
||||
mock_key_snippet_repository.delete.assert_any_call(1)
|
||||
mock_key_snippet_repository.delete.assert_any_call(3)
|
||||
assert mock_key_snippet_repository.delete.call_count == 2
|
||||
|
||||
|
||||
def test_emit_task_with_id(reset_memory):
|
||||
|
|
|
|||
Loading…
Reference in New Issue