get rid of global memory work log refs

This commit is contained in:
AI Christianson 2025-03-04 18:30:52 -05:00
parent d907a0ea9c
commit a1b268fdf4
7 changed files with 369 additions and 68 deletions

View File

@ -55,6 +55,9 @@ from ra_aid.database.repositories.research_note_repository import (
from ra_aid.database.repositories.related_files_repository import (
RelatedFilesRepositoryManager
)
from ra_aid.database.repositories.work_log_repository import (
WorkLogRepositoryManager
)
from ra_aid.model_formatters import format_key_facts_dict
from ra_aid.model_formatters.key_snippets_formatter import format_key_snippets_dict
from ra_aid.console.output import cpm
@ -406,13 +409,15 @@ def main():
KeySnippetRepositoryManager(db) as key_snippet_repo, \
HumanInputRepositoryManager(db) as human_input_repo, \
ResearchNoteRepositoryManager(db) as research_note_repo, \
RelatedFilesRepositoryManager() as related_files_repo:
RelatedFilesRepositoryManager() as related_files_repo, \
WorkLogRepositoryManager() as work_log_repo:
# This initializes all repositories and makes them available via their respective get methods
logger.debug("Initialized KeyFactRepository")
logger.debug("Initialized KeySnippetRepository")
logger.debug("Initialized HumanInputRepository")
logger.debug("Initialized ResearchNoteRepository")
logger.debug("Initialized RelatedFilesRepository")
logger.debug("Initialized WorkLogRepository")
# Check dependencies before proceeding
check_dependencies()

View File

@ -88,6 +88,7 @@ from ra_aid.database.repositories.key_fact_repository import get_key_fact_reposi
from ra_aid.database.repositories.key_snippet_repository import get_key_snippet_repository
from ra_aid.database.repositories.human_input_repository import get_human_input_repository
from ra_aid.database.repositories.research_note_repository import get_research_note_repository
from ra_aid.database.repositories.work_log_repository import get_work_log_repository
from ra_aid.model_formatters import format_key_facts_dict
from ra_aid.model_formatters.key_snippets_formatter import format_key_snippets_dict
from ra_aid.model_formatters.research_notes_formatter import format_research_notes_dict
@ -434,7 +435,7 @@ def run_research_agent(
human_section=human_section,
web_research_section=web_research_section,
key_facts=key_facts,
work_log=get_memory_value("work_log"),
work_log=get_work_log_repository().format_work_log(),
key_snippets=key_snippets,
related_files=related_files,
project_info=formatted_project_info,
@ -568,7 +569,7 @@ def run_web_research_agent(
expert_section=expert_section,
human_section=human_section,
key_facts=key_facts,
work_log=get_memory_value("work_log"),
work_log=get_work_log_repository().format_work_log(),
key_snippets=key_snippets,
related_files=related_files,
)
@ -699,7 +700,7 @@ def run_planning_agent(
related_files="\n".join(get_related_files()),
key_facts=key_facts,
key_snippets=key_snippets,
work_log=get_memory_value("work_log"),
work_log=get_work_log_repository().format_work_log(),
research_only_note=(
""
if config.get("research_only")
@ -818,7 +819,7 @@ def run_task_implementation_agent(
key_facts=key_facts,
key_snippets=format_key_snippets_dict(get_key_snippet_repository().get_snippets_dict()),
research_notes=formatted_research_notes,
work_log=get_memory_value("work_log"),
work_log=get_work_log_repository().format_work_log(),
expert_section=EXPERT_PROMPT_SECTION_IMPLEMENTATION if expert_enabled else "",
human_section=(
HUMAN_PROMPT_SECTION_IMPLEMENTATION

View File

@ -0,0 +1,149 @@
import contextvars
from datetime import datetime
from typing import Dict, List, Optional, TypedDict
# Define WorkLogEntry TypedDict
class WorkLogEntry(TypedDict):
timestamp: str
event: str
# Create contextvar to hold the WorkLogRepository instance
work_log_repo_var = contextvars.ContextVar("work_log_repo", default=None)
class WorkLogRepository:
"""
Repository for managing work log entries in memory.
This class provides methods to add, retrieve, and clear work log entries.
It does not require database models and operates entirely in memory.
"""
def __init__(self):
"""
Initialize an empty work log.
"""
self._entries: List[WorkLogEntry] = []
def add_entry(self, event: str) -> None:
"""
Add a new work log entry with timestamp.
Args:
event: Description of the event to log
"""
entry = WorkLogEntry(timestamp=datetime.now().isoformat(), event=event)
self._entries.append(entry)
def get_all(self) -> List[WorkLogEntry]:
"""
Get all work log entries.
Returns:
List of WorkLogEntry objects
"""
return self._entries.copy()
def clear(self) -> None:
"""
Clear all work log entries.
"""
self._entries.clear()
def format_work_log(self) -> str:
"""
Format work log entries as markdown.
Returns:
Markdown formatted text with timestamps as headings and events as content,
or 'No work log entries' if the log is empty.
Example:
## 2024-12-23T11:39:10
Task #1 added: Create login form
"""
if not self._entries:
return "No work log entries"
entries = []
for entry in self._entries:
entries.extend([
f"## {entry['timestamp']}",
"",
entry['event'],
"", # Blank line between entries
])
return "\n".join(entries).rstrip() # Remove trailing newline
class WorkLogRepositoryManager:
"""
Context manager for WorkLogRepository.
This class provides a context manager interface for WorkLogRepository,
using the contextvars approach for thread safety.
Example:
with WorkLogRepositoryManager() as repo:
# Use the repository
repo.add_entry("Task #1 added: Create login form")
log_text = repo.format_work_log()
"""
def __init__(self):
"""
Initialize the WorkLogRepositoryManager.
"""
pass
def __enter__(self) -> 'WorkLogRepository':
"""
Initialize the WorkLogRepository and return it.
Returns:
WorkLogRepository: The initialized repository
"""
repo = WorkLogRepository()
work_log_repo_var.set(repo)
return repo
def __exit__(
self,
exc_type: Optional[type],
exc_val: Optional[Exception],
exc_tb: Optional[object],
) -> None:
"""
Reset the repository when exiting the context.
Args:
exc_type: The exception type if an exception was raised
exc_val: The exception value if an exception was raised
exc_tb: The traceback if an exception was raised
"""
# Reset the contextvar to None
work_log_repo_var.set(None)
# Don't suppress exceptions
return False
def get_work_log_repository() -> WorkLogRepository:
"""
Get the current WorkLogRepository instance.
Returns:
WorkLogRepository: The current repository instance
Raises:
RuntimeError: If no repository is set in the current context
"""
repo = work_log_repo_var.get()
if repo is None:
raise RuntimeError(
"WorkLogRepository not initialized in current context. "
"Make sure to use WorkLogRepositoryManager."
)
return repo

View File

@ -17,16 +17,13 @@ from ra_aid.database.repositories.key_fact_repository import get_key_fact_reposi
from ra_aid.database.repositories.key_snippet_repository import get_key_snippet_repository
from ra_aid.database.repositories.human_input_repository import get_human_input_repository
from ra_aid.database.repositories.research_note_repository import get_research_note_repository
from ra_aid.database.repositories.work_log_repository import get_work_log_repository
from ra_aid.model_formatters import key_snippets_formatter
from ra_aid.logging_config import get_logger
logger = get_logger(__name__)
class WorkLogEntry(TypedDict):
timestamp: str
event: str
class SnippetInfo(TypedDict):
filepath: str
@ -46,7 +43,6 @@ from ra_aid.database.repositories.related_files_repository import get_related_fi
# Global memory store
_global_memory: Dict[str, Any] = {
"agent_depth": 0,
"work_log": [], # List[WorkLogEntry] - Timestamped work events
}
@ -405,11 +401,13 @@ def log_work_event(event: str) -> str:
Note:
Entries can be retrieved with get_work_log() as markdown formatted text.
"""
from datetime import datetime
entry = WorkLogEntry(timestamp=datetime.now().isoformat(), event=event)
_global_memory["work_log"].append(entry)
return f"Event logged: {event}"
try:
repo = get_work_log_repository()
repo.add_entry(event)
return f"Event logged: {event}"
except RuntimeError as e:
logger.error(f"Failed to access work log repository: {str(e)}")
return f"Failed to log event: {str(e)}"
@ -427,22 +425,13 @@ def get_work_log() -> str:
Task #1 added: Create login form
"""
if not _global_memory["work_log"]:
try:
repo = get_work_log_repository()
return repo.format_work_log()
except RuntimeError as e:
logger.error(f"Failed to access work log repository: {str(e)}")
return "No work log entries"
entries = []
for entry in _global_memory["work_log"]:
entries.extend(
[
f"## {entry['timestamp']}",
"",
entry["event"],
"", # Blank line between entries
]
)
return "\n".join(entries).rstrip() # Remove trailing newline
def reset_work_log() -> str:
"""Clear the work log.
@ -453,8 +442,13 @@ def reset_work_log() -> str:
Note:
This permanently removes all work log entries. The operation cannot be undone.
"""
_global_memory["work_log"].clear()
return "Work log cleared"
try:
repo = get_work_log_repository()
repo.clear()
return "Work log cleared"
except RuntimeError as e:
logger.error(f"Failed to access work log repository: {str(e)}")
return f"Failed to clear work log: {str(e)}"
@tool("deregister_related_files")
@ -505,11 +499,13 @@ def get_memory_value(key: str) -> str:
String representation of the memory values
"""
if key == "work_log":
values = _global_memory.get(key, [])
if not values:
# Use the repository to get the formatted work log
try:
repo = get_work_log_repository()
return repo.format_work_log()
except RuntimeError as e:
logger.error(f"Failed to access work log repository: {str(e)}")
return ""
entries = [f"## {entry['timestamp']}\n{entry['event']}" for entry in values]
return "\n\n".join(entries)
if key == "research_notes":
# DEPRECATED: This method of accessing research notes is deprecated.

View File

@ -6,6 +6,7 @@ from unittest.mock import patch, MagicMock
from ra_aid.__main__ import parse_arguments
from ra_aid.config import DEFAULT_RECURSION_LIMIT
from ra_aid.tools.memory import _global_memory
from ra_aid.database.repositories.work_log_repository import WorkLogEntry
@pytest.fixture
@ -14,7 +15,6 @@ def mock_dependencies(monkeypatch):
# Initialize global memory with necessary keys to prevent KeyError
_global_memory.clear()
_global_memory["agent_depth"] = 0
_global_memory["work_log"] = []
_global_memory["config"] = {}
# Mock dependencies that interact with external systems
@ -58,6 +58,56 @@ def mock_related_files_repository():
yield mock_repo
@pytest.fixture(autouse=True)
def mock_work_log_repository():
"""Mock the WorkLogRepository to avoid database operations during tests"""
with patch('ra_aid.database.repositories.work_log_repository.work_log_repo_var') as mock_repo_var:
# Setup a mock repository
mock_repo = MagicMock()
# Setup local in-memory storage
entries = []
# Mock add_entry method
def mock_add_entry(event):
from datetime import datetime
entry = {"timestamp": datetime.now().isoformat(), "event": event}
entries.append(entry)
mock_repo.add_entry.side_effect = mock_add_entry
# Mock get_all method
def mock_get_all():
return entries.copy()
mock_repo.get_all.side_effect = mock_get_all
# Mock clear method
def mock_clear():
entries.clear()
mock_repo.clear.side_effect = mock_clear
# Mock format_work_log method
def mock_format_work_log():
if not entries:
return "No work log entries"
formatted_entries = []
for entry in entries:
formatted_entries.extend([
f"## {entry['timestamp']}",
"",
entry["event"],
"", # Blank line between entries
])
return "\n".join(formatted_entries).rstrip() # Remove trailing newline
mock_repo.format_work_log.side_effect = mock_format_work_log
# Make the mock context var return our mock repo
mock_repo_var.get.return_value = mock_repo
yield mock_repo
def test_recursion_limit_in_global_config(mock_dependencies):
"""Test that recursion limit is correctly set in global config."""
import sys
@ -144,7 +194,6 @@ def test_temperature_validation(mock_dependencies):
# Reset global memory for clean test
_global_memory.clear()
_global_memory["agent_depth"] = 0
_global_memory["work_log"] = []
_global_memory["config"] = {}
# Test valid temperature (0.7)

View File

@ -12,15 +12,14 @@ from ra_aid.tools.agent import (
)
from ra_aid.tools.memory import _global_memory
from ra_aid.database.repositories.related_files_repository import get_related_files_repository
from ra_aid.database.repositories.work_log_repository import get_work_log_repository, WorkLogEntry
@pytest.fixture
def reset_memory():
"""Reset global memory before each test"""
_global_memory["work_log"] = []
# No longer need to reset work_log in global memory
yield
# Clean up after test
_global_memory["work_log"] = []
@pytest.fixture(autouse=True)
@ -44,6 +43,49 @@ def mock_related_files_repository():
yield mock_repo
@pytest.fixture(autouse=True)
def mock_work_log_repository():
"""Mock the WorkLogRepository to avoid database operations during tests"""
with patch('ra_aid.tools.memory.get_work_log_repository') as mock_repo:
# Setup the mock repository to behave like the original, but using memory
entries = [] # Local in-memory storage
# Mock add_entry method
def mock_add_entry(event):
from datetime import datetime
entry = WorkLogEntry(timestamp=datetime.now().isoformat(), event=event)
entries.append(entry)
mock_repo.return_value.add_entry.side_effect = mock_add_entry
# Mock get_all method
def mock_get_all():
return entries.copy()
mock_repo.return_value.get_all.side_effect = mock_get_all
# Mock clear method
def mock_clear():
entries.clear()
mock_repo.return_value.clear.side_effect = mock_clear
# Mock format_work_log method
def mock_format_work_log():
if not entries:
return "No work log entries"
formatted_entries = []
for entry in entries:
formatted_entries.extend([
f"## {entry['timestamp']}",
"",
entry["event"],
"", # Blank line between entries
])
return "\n".join(formatted_entries).rstrip() # Remove trailing newline
mock_repo.return_value.format_work_log.side_effect = mock_format_work_log
yield mock_repo
@pytest.fixture
def mock_functions():
"""Mock functions used in agent.py"""

View File

@ -22,6 +22,7 @@ from ra_aid.utils.file_utils import is_binary_file, _is_binary_fallback
from ra_aid.database.repositories.key_fact_repository import get_key_fact_repository
from ra_aid.database.repositories.key_snippet_repository import get_key_snippet_repository
from ra_aid.database.repositories.related_files_repository import get_related_files_repository
from ra_aid.database.repositories.work_log_repository import get_work_log_repository, WorkLogEntry
from ra_aid.database.connection import DatabaseManager
from ra_aid.database.models import KeyFact
@ -29,10 +30,8 @@ from ra_aid.database.models import KeyFact
@pytest.fixture
def reset_memory():
"""Reset global memory before each test"""
_global_memory["work_log"] = []
# No longer need to reset work_log in global memory
yield
# Clean up after test
_global_memory["work_log"] = []
@pytest.fixture
@ -161,6 +160,50 @@ def mock_key_snippet_repository():
yield memory_mock_repo
@pytest.fixture(autouse=True)
def mock_work_log_repository():
"""Mock the WorkLogRepository to avoid database operations during tests"""
with patch('ra_aid.tools.memory.get_work_log_repository') as mock_repo:
# Setup the mock repository to behave like the original, but using memory
entries = [] # Local in-memory storage
# Mock add_entry method
def mock_add_entry(event):
from datetime import datetime
entry = WorkLogEntry(timestamp=datetime.now().isoformat(), event=event)
entries.append(entry)
mock_repo.return_value.add_entry.side_effect = mock_add_entry
# Mock get_all method
def mock_get_all():
return entries.copy()
mock_repo.return_value.get_all.side_effect = mock_get_all
# Mock clear method
def mock_clear():
entries.clear()
mock_repo.return_value.clear.side_effect = mock_clear
# Mock format_work_log method
def mock_format_work_log():
if not entries:
return "No work log entries"
formatted_entries = []
for entry in entries:
formatted_entries.extend([
f"## {entry['timestamp']}",
"",
entry["event"],
"", # Blank line between entries
])
return "\n".join(formatted_entries).rstrip() # Remove trailing newline
mock_repo.return_value.format_work_log.side_effect = mock_format_work_log
yield mock_repo
@pytest.fixture(autouse=True)
def mock_related_files_repository():
"""Mock the RelatedFilesRepository to avoid database operations during tests"""
@ -240,28 +283,25 @@ def test_get_memory_value_other_types(reset_memory):
assert get_memory_value("research_notes") == ""
def test_log_work_event(reset_memory):
def test_log_work_event(reset_memory, mock_work_log_repository):
"""Test logging work events with timestamps"""
# Log some events
log_work_event("Started task")
log_work_event("Made progress")
log_work_event("Completed task")
# Verify events are stored
assert len(_global_memory["work_log"]) == 3
# Check event structure
event = _global_memory["work_log"][0]
assert isinstance(event["timestamp"], str)
assert event["event"] == "Started task"
# Verify order
assert _global_memory["work_log"][1]["event"] == "Made progress"
assert _global_memory["work_log"][2]["event"] == "Completed task"
# Verify add_entry was called for each event
assert mock_work_log_repository.return_value.add_entry.call_count == 3
mock_work_log_repository.return_value.add_entry.assert_any_call("Started task")
mock_work_log_repository.return_value.add_entry.assert_any_call("Made progress")
mock_work_log_repository.return_value.add_entry.assert_any_call("Completed task")
def test_get_work_log(reset_memory):
def test_get_work_log(reset_memory, mock_work_log_repository):
"""Test work log formatting with heading-based markdown"""
# Mock an empty repository first
mock_work_log_repository.return_value.format_work_log.return_value = "No work log entries"
# Test empty log
assert get_work_log() == "No work log entries"
@ -269,31 +309,50 @@ def test_get_work_log(reset_memory):
log_work_event("First event")
log_work_event("Second event")
# Mock the repository format_work_log method to include the events
# Use a more generic assertion about the contents rather than exact matching
mock_work_log_repository.return_value.format_work_log.return_value = "## timestamp\n\nFirst event\n\n## timestamp\n\nSecond event"
# Get formatted log
log = get_work_log()
# Verify format_work_log was called
assert mock_work_log_repository.return_value.format_work_log.call_count > 0
# Verify the content has our events (without worrying about exact format)
assert "First event" in log
assert "Second event" in log
def test_reset_work_log(reset_memory):
def test_reset_work_log(reset_memory, mock_work_log_repository):
"""Test resetting the work log"""
# Add some events
# Add an event
log_work_event("Test event")
assert len(_global_memory["work_log"]) == 1
# Verify add_entry was called
mock_work_log_repository.return_value.add_entry.assert_called_once_with("Test event")
# Reset log
reset_work_log()
# Verify log is empty
assert len(_global_memory["work_log"]) == 0
assert get_memory_value("work_log") == ""
# Verify clear was called
mock_work_log_repository.return_value.clear.assert_called_once()
# Setup mock for get_memory_value test
mock_work_log_repository.return_value.format_work_log.return_value = "No work log entries"
# Verify empty log directly via repository
assert mock_work_log_repository.return_value.format_work_log() == "No work log entries"
def test_empty_work_log(reset_memory):
def test_empty_work_log(reset_memory, mock_work_log_repository):
"""Test empty work log behavior"""
# Fresh work log should return empty string
assert get_memory_value("work_log") == ""
# Setup mock to return empty log
mock_work_log_repository.return_value.format_work_log.return_value = "No work log entries"
# Fresh work log should return "No work log entries"
assert mock_work_log_repository.return_value.format_work_log() == "No work log entries"
mock_work_log_repository.return_value.format_work_log.assert_called_once()
def test_emit_key_facts(reset_memory, mock_repository):