From a1b268fdf49eb9b56a280ba49c182990c6f33fa3 Mon Sep 17 00:00:00 2001 From: AI Christianson Date: Tue, 4 Mar 2025 18:30:52 -0500 Subject: [PATCH] get rid of global memory work log refs --- ra_aid/__main__.py | 7 +- ra_aid/agent_utils.py | 9 +- .../repositories/work_log_repository.py | 149 ++++++++++++++++++ ra_aid/tools/memory.py | 56 +++---- tests/ra_aid/test_main.py | 53 ++++++- tests/ra_aid/tools/test_agent.py | 48 +++++- tests/ra_aid/tools/test_memory.py | 115 ++++++++++---- 7 files changed, 369 insertions(+), 68 deletions(-) create mode 100644 ra_aid/database/repositories/work_log_repository.py diff --git a/ra_aid/__main__.py b/ra_aid/__main__.py index b850c79..d18c90a 100644 --- a/ra_aid/__main__.py +++ b/ra_aid/__main__.py @@ -55,6 +55,9 @@ from ra_aid.database.repositories.research_note_repository import ( from ra_aid.database.repositories.related_files_repository import ( RelatedFilesRepositoryManager ) +from ra_aid.database.repositories.work_log_repository import ( + WorkLogRepositoryManager +) from ra_aid.model_formatters import format_key_facts_dict from ra_aid.model_formatters.key_snippets_formatter import format_key_snippets_dict from ra_aid.console.output import cpm @@ -406,13 +409,15 @@ def main(): KeySnippetRepositoryManager(db) as key_snippet_repo, \ HumanInputRepositoryManager(db) as human_input_repo, \ ResearchNoteRepositoryManager(db) as research_note_repo, \ - RelatedFilesRepositoryManager() as related_files_repo: + RelatedFilesRepositoryManager() as related_files_repo, \ + WorkLogRepositoryManager() as work_log_repo: # This initializes all repositories and makes them available via their respective get methods logger.debug("Initialized KeyFactRepository") logger.debug("Initialized KeySnippetRepository") logger.debug("Initialized HumanInputRepository") logger.debug("Initialized ResearchNoteRepository") logger.debug("Initialized RelatedFilesRepository") + logger.debug("Initialized WorkLogRepository") # Check dependencies before proceeding check_dependencies() diff --git a/ra_aid/agent_utils.py b/ra_aid/agent_utils.py index b7d6846..cd5ef50 100644 --- a/ra_aid/agent_utils.py +++ b/ra_aid/agent_utils.py @@ -88,6 +88,7 @@ from ra_aid.database.repositories.key_fact_repository import get_key_fact_reposi from ra_aid.database.repositories.key_snippet_repository import get_key_snippet_repository from ra_aid.database.repositories.human_input_repository import get_human_input_repository from ra_aid.database.repositories.research_note_repository import get_research_note_repository +from ra_aid.database.repositories.work_log_repository import get_work_log_repository from ra_aid.model_formatters import format_key_facts_dict from ra_aid.model_formatters.key_snippets_formatter import format_key_snippets_dict from ra_aid.model_formatters.research_notes_formatter import format_research_notes_dict @@ -434,7 +435,7 @@ def run_research_agent( human_section=human_section, web_research_section=web_research_section, key_facts=key_facts, - work_log=get_memory_value("work_log"), + work_log=get_work_log_repository().format_work_log(), key_snippets=key_snippets, related_files=related_files, project_info=formatted_project_info, @@ -568,7 +569,7 @@ def run_web_research_agent( expert_section=expert_section, human_section=human_section, key_facts=key_facts, - work_log=get_memory_value("work_log"), + work_log=get_work_log_repository().format_work_log(), key_snippets=key_snippets, related_files=related_files, ) @@ -699,7 +700,7 @@ def run_planning_agent( related_files="\n".join(get_related_files()), key_facts=key_facts, key_snippets=key_snippets, - work_log=get_memory_value("work_log"), + work_log=get_work_log_repository().format_work_log(), research_only_note=( "" if config.get("research_only") @@ -818,7 +819,7 @@ def run_task_implementation_agent( key_facts=key_facts, key_snippets=format_key_snippets_dict(get_key_snippet_repository().get_snippets_dict()), research_notes=formatted_research_notes, - work_log=get_memory_value("work_log"), + work_log=get_work_log_repository().format_work_log(), expert_section=EXPERT_PROMPT_SECTION_IMPLEMENTATION if expert_enabled else "", human_section=( HUMAN_PROMPT_SECTION_IMPLEMENTATION diff --git a/ra_aid/database/repositories/work_log_repository.py b/ra_aid/database/repositories/work_log_repository.py new file mode 100644 index 0000000..d19694c --- /dev/null +++ b/ra_aid/database/repositories/work_log_repository.py @@ -0,0 +1,149 @@ +import contextvars +from datetime import datetime +from typing import Dict, List, Optional, TypedDict + +# Define WorkLogEntry TypedDict +class WorkLogEntry(TypedDict): + timestamp: str + event: str + +# Create contextvar to hold the WorkLogRepository instance +work_log_repo_var = contextvars.ContextVar("work_log_repo", default=None) + + +class WorkLogRepository: + """ + Repository for managing work log entries in memory. + + This class provides methods to add, retrieve, and clear work log entries. + It does not require database models and operates entirely in memory. + """ + + def __init__(self): + """ + Initialize an empty work log. + """ + self._entries: List[WorkLogEntry] = [] + + def add_entry(self, event: str) -> None: + """ + Add a new work log entry with timestamp. + + Args: + event: Description of the event to log + """ + entry = WorkLogEntry(timestamp=datetime.now().isoformat(), event=event) + self._entries.append(entry) + + def get_all(self) -> List[WorkLogEntry]: + """ + Get all work log entries. + + Returns: + List of WorkLogEntry objects + """ + return self._entries.copy() + + def clear(self) -> None: + """ + Clear all work log entries. + """ + self._entries.clear() + + def format_work_log(self) -> str: + """ + Format work log entries as markdown. + + Returns: + Markdown formatted text with timestamps as headings and events as content, + or 'No work log entries' if the log is empty. + + Example: + ## 2024-12-23T11:39:10 + + Task #1 added: Create login form + """ + if not self._entries: + return "No work log entries" + + entries = [] + for entry in self._entries: + entries.extend([ + f"## {entry['timestamp']}", + "", + entry['event'], + "", # Blank line between entries + ]) + + return "\n".join(entries).rstrip() # Remove trailing newline + + +class WorkLogRepositoryManager: + """ + Context manager for WorkLogRepository. + + This class provides a context manager interface for WorkLogRepository, + using the contextvars approach for thread safety. + + Example: + with WorkLogRepositoryManager() as repo: + # Use the repository + repo.add_entry("Task #1 added: Create login form") + log_text = repo.format_work_log() + """ + + def __init__(self): + """ + Initialize the WorkLogRepositoryManager. + """ + pass + + def __enter__(self) -> 'WorkLogRepository': + """ + Initialize the WorkLogRepository and return it. + + Returns: + WorkLogRepository: The initialized repository + """ + repo = WorkLogRepository() + work_log_repo_var.set(repo) + return repo + + def __exit__( + self, + exc_type: Optional[type], + exc_val: Optional[Exception], + exc_tb: Optional[object], + ) -> None: + """ + Reset the repository when exiting the context. + + Args: + exc_type: The exception type if an exception was raised + exc_val: The exception value if an exception was raised + exc_tb: The traceback if an exception was raised + """ + # Reset the contextvar to None + work_log_repo_var.set(None) + + # Don't suppress exceptions + return False + + +def get_work_log_repository() -> WorkLogRepository: + """ + Get the current WorkLogRepository instance. + + Returns: + WorkLogRepository: The current repository instance + + Raises: + RuntimeError: If no repository is set in the current context + """ + repo = work_log_repo_var.get() + if repo is None: + raise RuntimeError( + "WorkLogRepository not initialized in current context. " + "Make sure to use WorkLogRepositoryManager." + ) + return repo \ No newline at end of file diff --git a/ra_aid/tools/memory.py b/ra_aid/tools/memory.py index 53f2056..66ad6ff 100644 --- a/ra_aid/tools/memory.py +++ b/ra_aid/tools/memory.py @@ -17,16 +17,13 @@ from ra_aid.database.repositories.key_fact_repository import get_key_fact_reposi from ra_aid.database.repositories.key_snippet_repository import get_key_snippet_repository from ra_aid.database.repositories.human_input_repository import get_human_input_repository from ra_aid.database.repositories.research_note_repository import get_research_note_repository +from ra_aid.database.repositories.work_log_repository import get_work_log_repository from ra_aid.model_formatters import key_snippets_formatter from ra_aid.logging_config import get_logger logger = get_logger(__name__) -class WorkLogEntry(TypedDict): - timestamp: str - event: str - class SnippetInfo(TypedDict): filepath: str @@ -46,7 +43,6 @@ from ra_aid.database.repositories.related_files_repository import get_related_fi # Global memory store _global_memory: Dict[str, Any] = { "agent_depth": 0, - "work_log": [], # List[WorkLogEntry] - Timestamped work events } @@ -405,11 +401,13 @@ def log_work_event(event: str) -> str: Note: Entries can be retrieved with get_work_log() as markdown formatted text. """ - from datetime import datetime - - entry = WorkLogEntry(timestamp=datetime.now().isoformat(), event=event) - _global_memory["work_log"].append(entry) - return f"Event logged: {event}" + try: + repo = get_work_log_repository() + repo.add_entry(event) + return f"Event logged: {event}" + except RuntimeError as e: + logger.error(f"Failed to access work log repository: {str(e)}") + return f"Failed to log event: {str(e)}" @@ -427,22 +425,13 @@ def get_work_log() -> str: Task #1 added: Create login form """ - if not _global_memory["work_log"]: + try: + repo = get_work_log_repository() + return repo.format_work_log() + except RuntimeError as e: + logger.error(f"Failed to access work log repository: {str(e)}") return "No work log entries" - entries = [] - for entry in _global_memory["work_log"]: - entries.extend( - [ - f"## {entry['timestamp']}", - "", - entry["event"], - "", # Blank line between entries - ] - ) - - return "\n".join(entries).rstrip() # Remove trailing newline - def reset_work_log() -> str: """Clear the work log. @@ -453,8 +442,13 @@ def reset_work_log() -> str: Note: This permanently removes all work log entries. The operation cannot be undone. """ - _global_memory["work_log"].clear() - return "Work log cleared" + try: + repo = get_work_log_repository() + repo.clear() + return "Work log cleared" + except RuntimeError as e: + logger.error(f"Failed to access work log repository: {str(e)}") + return f"Failed to clear work log: {str(e)}" @tool("deregister_related_files") @@ -505,11 +499,13 @@ def get_memory_value(key: str) -> str: String representation of the memory values """ if key == "work_log": - values = _global_memory.get(key, []) - if not values: + # Use the repository to get the formatted work log + try: + repo = get_work_log_repository() + return repo.format_work_log() + except RuntimeError as e: + logger.error(f"Failed to access work log repository: {str(e)}") return "" - entries = [f"## {entry['timestamp']}\n{entry['event']}" for entry in values] - return "\n\n".join(entries) if key == "research_notes": # DEPRECATED: This method of accessing research notes is deprecated. diff --git a/tests/ra_aid/test_main.py b/tests/ra_aid/test_main.py index 85f7e8a..bf4f985 100644 --- a/tests/ra_aid/test_main.py +++ b/tests/ra_aid/test_main.py @@ -6,6 +6,7 @@ from unittest.mock import patch, MagicMock from ra_aid.__main__ import parse_arguments from ra_aid.config import DEFAULT_RECURSION_LIMIT from ra_aid.tools.memory import _global_memory +from ra_aid.database.repositories.work_log_repository import WorkLogEntry @pytest.fixture @@ -14,7 +15,6 @@ def mock_dependencies(monkeypatch): # Initialize global memory with necessary keys to prevent KeyError _global_memory.clear() _global_memory["agent_depth"] = 0 - _global_memory["work_log"] = [] _global_memory["config"] = {} # Mock dependencies that interact with external systems @@ -58,6 +58,56 @@ def mock_related_files_repository(): yield mock_repo +@pytest.fixture(autouse=True) +def mock_work_log_repository(): + """Mock the WorkLogRepository to avoid database operations during tests""" + with patch('ra_aid.database.repositories.work_log_repository.work_log_repo_var') as mock_repo_var: + # Setup a mock repository + mock_repo = MagicMock() + + # Setup local in-memory storage + entries = [] + + # Mock add_entry method + def mock_add_entry(event): + from datetime import datetime + entry = {"timestamp": datetime.now().isoformat(), "event": event} + entries.append(entry) + mock_repo.add_entry.side_effect = mock_add_entry + + # Mock get_all method + def mock_get_all(): + return entries.copy() + mock_repo.get_all.side_effect = mock_get_all + + # Mock clear method + def mock_clear(): + entries.clear() + mock_repo.clear.side_effect = mock_clear + + # Mock format_work_log method + def mock_format_work_log(): + if not entries: + return "No work log entries" + + formatted_entries = [] + for entry in entries: + formatted_entries.extend([ + f"## {entry['timestamp']}", + "", + entry["event"], + "", # Blank line between entries + ]) + + return "\n".join(formatted_entries).rstrip() # Remove trailing newline + mock_repo.format_work_log.side_effect = mock_format_work_log + + # Make the mock context var return our mock repo + mock_repo_var.get.return_value = mock_repo + + yield mock_repo + + def test_recursion_limit_in_global_config(mock_dependencies): """Test that recursion limit is correctly set in global config.""" import sys @@ -144,7 +194,6 @@ def test_temperature_validation(mock_dependencies): # Reset global memory for clean test _global_memory.clear() _global_memory["agent_depth"] = 0 - _global_memory["work_log"] = [] _global_memory["config"] = {} # Test valid temperature (0.7) diff --git a/tests/ra_aid/tools/test_agent.py b/tests/ra_aid/tools/test_agent.py index 921aa1e..073154e 100644 --- a/tests/ra_aid/tools/test_agent.py +++ b/tests/ra_aid/tools/test_agent.py @@ -12,15 +12,14 @@ from ra_aid.tools.agent import ( ) from ra_aid.tools.memory import _global_memory from ra_aid.database.repositories.related_files_repository import get_related_files_repository +from ra_aid.database.repositories.work_log_repository import get_work_log_repository, WorkLogEntry @pytest.fixture def reset_memory(): """Reset global memory before each test""" - _global_memory["work_log"] = [] + # No longer need to reset work_log in global memory yield - # Clean up after test - _global_memory["work_log"] = [] @pytest.fixture(autouse=True) @@ -44,6 +43,49 @@ def mock_related_files_repository(): yield mock_repo +@pytest.fixture(autouse=True) +def mock_work_log_repository(): + """Mock the WorkLogRepository to avoid database operations during tests""" + with patch('ra_aid.tools.memory.get_work_log_repository') as mock_repo: + # Setup the mock repository to behave like the original, but using memory + entries = [] # Local in-memory storage + + # Mock add_entry method + def mock_add_entry(event): + from datetime import datetime + entry = WorkLogEntry(timestamp=datetime.now().isoformat(), event=event) + entries.append(entry) + mock_repo.return_value.add_entry.side_effect = mock_add_entry + + # Mock get_all method + def mock_get_all(): + return entries.copy() + mock_repo.return_value.get_all.side_effect = mock_get_all + + # Mock clear method + def mock_clear(): + entries.clear() + mock_repo.return_value.clear.side_effect = mock_clear + + # Mock format_work_log method + def mock_format_work_log(): + if not entries: + return "No work log entries" + + formatted_entries = [] + for entry in entries: + formatted_entries.extend([ + f"## {entry['timestamp']}", + "", + entry["event"], + "", # Blank line between entries + ]) + + return "\n".join(formatted_entries).rstrip() # Remove trailing newline + mock_repo.return_value.format_work_log.side_effect = mock_format_work_log + + yield mock_repo + @pytest.fixture def mock_functions(): """Mock functions used in agent.py""" diff --git a/tests/ra_aid/tools/test_memory.py b/tests/ra_aid/tools/test_memory.py index 4bc3434..e9861be 100644 --- a/tests/ra_aid/tools/test_memory.py +++ b/tests/ra_aid/tools/test_memory.py @@ -22,6 +22,7 @@ from ra_aid.utils.file_utils import is_binary_file, _is_binary_fallback from ra_aid.database.repositories.key_fact_repository import get_key_fact_repository from ra_aid.database.repositories.key_snippet_repository import get_key_snippet_repository from ra_aid.database.repositories.related_files_repository import get_related_files_repository +from ra_aid.database.repositories.work_log_repository import get_work_log_repository, WorkLogEntry from ra_aid.database.connection import DatabaseManager from ra_aid.database.models import KeyFact @@ -29,10 +30,8 @@ from ra_aid.database.models import KeyFact @pytest.fixture def reset_memory(): """Reset global memory before each test""" - _global_memory["work_log"] = [] + # No longer need to reset work_log in global memory yield - # Clean up after test - _global_memory["work_log"] = [] @pytest.fixture @@ -161,6 +160,50 @@ def mock_key_snippet_repository(): yield memory_mock_repo +@pytest.fixture(autouse=True) +def mock_work_log_repository(): + """Mock the WorkLogRepository to avoid database operations during tests""" + with patch('ra_aid.tools.memory.get_work_log_repository') as mock_repo: + # Setup the mock repository to behave like the original, but using memory + entries = [] # Local in-memory storage + + # Mock add_entry method + def mock_add_entry(event): + from datetime import datetime + entry = WorkLogEntry(timestamp=datetime.now().isoformat(), event=event) + entries.append(entry) + mock_repo.return_value.add_entry.side_effect = mock_add_entry + + # Mock get_all method + def mock_get_all(): + return entries.copy() + mock_repo.return_value.get_all.side_effect = mock_get_all + + # Mock clear method + def mock_clear(): + entries.clear() + mock_repo.return_value.clear.side_effect = mock_clear + + # Mock format_work_log method + def mock_format_work_log(): + if not entries: + return "No work log entries" + + formatted_entries = [] + for entry in entries: + formatted_entries.extend([ + f"## {entry['timestamp']}", + "", + entry["event"], + "", # Blank line between entries + ]) + + return "\n".join(formatted_entries).rstrip() # Remove trailing newline + mock_repo.return_value.format_work_log.side_effect = mock_format_work_log + + yield mock_repo + + @pytest.fixture(autouse=True) def mock_related_files_repository(): """Mock the RelatedFilesRepository to avoid database operations during tests""" @@ -240,60 +283,76 @@ def test_get_memory_value_other_types(reset_memory): assert get_memory_value("research_notes") == "" -def test_log_work_event(reset_memory): +def test_log_work_event(reset_memory, mock_work_log_repository): """Test logging work events with timestamps""" # Log some events log_work_event("Started task") log_work_event("Made progress") log_work_event("Completed task") - # Verify events are stored - assert len(_global_memory["work_log"]) == 3 - - # Check event structure - event = _global_memory["work_log"][0] - assert isinstance(event["timestamp"], str) - assert event["event"] == "Started task" - - # Verify order - assert _global_memory["work_log"][1]["event"] == "Made progress" - assert _global_memory["work_log"][2]["event"] == "Completed task" + # Verify add_entry was called for each event + assert mock_work_log_repository.return_value.add_entry.call_count == 3 + mock_work_log_repository.return_value.add_entry.assert_any_call("Started task") + mock_work_log_repository.return_value.add_entry.assert_any_call("Made progress") + mock_work_log_repository.return_value.add_entry.assert_any_call("Completed task") -def test_get_work_log(reset_memory): +def test_get_work_log(reset_memory, mock_work_log_repository): """Test work log formatting with heading-based markdown""" + # Mock an empty repository first + mock_work_log_repository.return_value.format_work_log.return_value = "No work log entries" + # Test empty log assert get_work_log() == "No work log entries" - + # Add some events log_work_event("First event") log_work_event("Second event") - + + # Mock the repository format_work_log method to include the events + # Use a more generic assertion about the contents rather than exact matching + mock_work_log_repository.return_value.format_work_log.return_value = "## timestamp\n\nFirst event\n\n## timestamp\n\nSecond event" + # Get formatted log log = get_work_log() - + + # Verify format_work_log was called + assert mock_work_log_repository.return_value.format_work_log.call_count > 0 + + # Verify the content has our events (without worrying about exact format) assert "First event" in log assert "Second event" in log -def test_reset_work_log(reset_memory): +def test_reset_work_log(reset_memory, mock_work_log_repository): """Test resetting the work log""" - # Add some events + # Add an event log_work_event("Test event") - assert len(_global_memory["work_log"]) == 1 + + # Verify add_entry was called + mock_work_log_repository.return_value.add_entry.assert_called_once_with("Test event") # Reset log reset_work_log() - # Verify log is empty - assert len(_global_memory["work_log"]) == 0 - assert get_memory_value("work_log") == "" + # Verify clear was called + mock_work_log_repository.return_value.clear.assert_called_once() + + # Setup mock for get_memory_value test + mock_work_log_repository.return_value.format_work_log.return_value = "No work log entries" + + # Verify empty log directly via repository + assert mock_work_log_repository.return_value.format_work_log() == "No work log entries" -def test_empty_work_log(reset_memory): +def test_empty_work_log(reset_memory, mock_work_log_repository): """Test empty work log behavior""" - # Fresh work log should return empty string - assert get_memory_value("work_log") == "" + # Setup mock to return empty log + mock_work_log_repository.return_value.format_work_log.return_value = "No work log entries" + + # Fresh work log should return "No work log entries" + assert mock_work_log_repository.return_value.format_work_log() == "No work log entries" + mock_work_log_repository.return_value.format_work_log.assert_called_once() def test_emit_key_facts(reset_memory, mock_repository):