use repo pattern

This commit is contained in:
AI Christianson 2025-03-04 14:59:45 -05:00
parent 0afed55809
commit 600bf355d9
15 changed files with 559 additions and 417 deletions

View File

@ -52,6 +52,9 @@ from ra_aid.database.repositories.human_input_repository import (
from ra_aid.database.repositories.research_note_repository import (
ResearchNoteRepositoryManager, get_research_note_repository
)
from ra_aid.database.repositories.related_files_repository import (
RelatedFilesRepositoryManager
)
from ra_aid.model_formatters import format_key_facts_dict
from ra_aid.model_formatters.key_snippets_formatter import format_key_snippets_dict
from ra_aid.console.output import cpm
@ -402,12 +405,14 @@ def main():
with KeyFactRepositoryManager(db) as key_fact_repo, \
KeySnippetRepositoryManager(db) as key_snippet_repo, \
HumanInputRepositoryManager(db) as human_input_repo, \
ResearchNoteRepositoryManager(db) as research_note_repo:
ResearchNoteRepositoryManager(db) as research_note_repo, \
RelatedFilesRepositoryManager() as related_files_repo:
# This initializes all repositories and makes them available via their respective get methods
logger.debug("Initialized KeyFactRepository")
logger.debug("Initialized KeySnippetRepository")
logger.debug("Initialized HumanInputRepository")
logger.debug("Initialized ResearchNoteRepository")
logger.debug("Initialized RelatedFilesRepository")
# Check dependencies before proceeding
check_dependencies()

View File

@ -397,7 +397,7 @@ def run_research_agent(
logger.error(f"Failed to access key fact repository: {str(e)}")
key_facts = ""
code_snippets = _global_memory.get("code_snippets", "")
related_files = _global_memory.get("related_files", "")
related_files = get_related_files()
current_date = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
working_directory = os.getcwd()
@ -552,7 +552,7 @@ def run_web_research_agent(
logger.error(f"Failed to access key fact repository: {str(e)}")
key_facts = ""
code_snippets = _global_memory.get("code_snippets", "")
related_files = _global_memory.get("related_files", "")
related_files = get_related_files()
current_date = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
working_directory = os.getcwd()

View File

@ -20,6 +20,11 @@ from ra_aid.database.repositories.key_snippet_repository import (
KeySnippetRepositoryManager,
get_key_snippet_repository
)
from ra_aid.database.repositories.related_files_repository import (
RelatedFilesRepository,
RelatedFilesRepositoryManager,
get_related_files_repository
)
from ra_aid.database.repositories.research_note_repository import (
ResearchNoteRepository,
ResearchNoteRepositoryManager,
@ -36,6 +41,9 @@ __all__ = [
'KeySnippetRepository',
'KeySnippetRepositoryManager',
'get_key_snippet_repository',
'RelatedFilesRepository',
'RelatedFilesRepositoryManager',
'get_related_files_repository',
'ResearchNoteRepository',
'ResearchNoteRepositoryManager',
'get_research_note_repository',

View File

@ -0,0 +1,169 @@
import contextvars
import os
from typing import Dict, List, Optional
# Import is_binary_file from memory.py
from ra_aid.utils.file_utils import is_binary_file
# Create contextvar to hold the RelatedFilesRepository instance
related_files_repo_var = contextvars.ContextVar("related_files_repo", default=None)
class RelatedFilesRepository:
    """
    In-memory repository for tracking files related to the current task.

    Provides add/remove/lookup operations over an ID -> filepath mapping.
    No database models are involved; all state lives in this object.
    """

    def __init__(self):
        """Initialize with an empty file map and an ID counter starting at 1."""
        self._related_files: Dict[int, str] = {}
        self._id_counter: int = 1

    def get_all(self) -> Dict[int, str]:
        """
        Return a copy of the ID -> filepath mapping.

        Returns:
            Dict[int, str]: Mapping of file IDs to stored file paths.
        """
        return dict(self._related_files)

    def add_file(self, filepath: str) -> Optional[int]:
        """
        Register a file, returning its assigned ID.

        Existing entries are matched by normalized (absolute) path, in which
        case the original ID is returned instead of creating a duplicate.

        Args:
            filepath: Path of the file to register.

        Returns:
            Optional[int]: The file's ID, or None when the path does not
            exist, is a directory, is not a regular file, or is binary.
        """
        # Reject anything that is not an existing, regular, text file.
        if not os.path.exists(filepath):
            return None
        if os.path.isdir(filepath):
            return None
        if not os.path.isfile(filepath):
            return None
        if is_binary_file(filepath):
            return None

        normalized = os.path.abspath(filepath)

        # Reuse the existing ID if this path was already registered.
        for existing_id, existing_path in self._related_files.items():
            if existing_path == normalized:
                return existing_id

        new_id = self._id_counter
        self._id_counter += 1
        self._related_files[new_id] = normalized
        return new_id

    def remove_file(self, file_id: int) -> Optional[str]:
        """
        Remove a file by ID.

        Args:
            file_id: ID of the entry to remove.

        Returns:
            Optional[str]: The removed path, or None if the ID was unknown.
        """
        return self._related_files.pop(file_id, None)

    def format_related_files(self) -> List[str]:
        """
        Render entries as 'ID#<id> <path>' strings, ordered by ID.

        Returns:
            List[str]: One formatted line per tracked file.
        """
        return [
            f"ID#{fid} {path}"
            for fid, path in sorted(self._related_files.items())
        ]
class RelatedFilesRepositoryManager:
    """
    Context manager that installs a RelatedFilesRepository for the
    current context.

    A fresh repository is bound to the ``related_files_repo_var``
    contextvar on entry and cleared again on exit, keeping the
    repository isolated per thread/task.

    Example:
        with RelatedFilesRepositoryManager() as repo:
            # Use the repository
            file_id = repo.add_file("path/to/file.py")
            all_files = repo.get_all()
    """

    def __init__(self):
        """No configuration is needed; the repository is built on entry."""

    def __enter__(self) -> 'RelatedFilesRepository':
        """
        Create a repository, publish it via the contextvar, and return it.

        Returns:
            RelatedFilesRepository: The newly created repository.
        """
        repository = RelatedFilesRepository()
        related_files_repo_var.set(repository)
        return repository

    def __exit__(
        self,
        exc_type: Optional[type],
        exc_val: Optional[Exception],
        exc_tb: Optional[object],
    ) -> None:
        """
        Clear the contextvar when leaving the context.

        Args:
            exc_type: Exception type, if one was raised.
            exc_val: Exception instance, if one was raised.
            exc_tb: Traceback, if one was raised.
        """
        related_files_repo_var.set(None)
        # Returning False propagates any exception from the with-block.
        return False
def get_related_files_repository() -> RelatedFilesRepository:
    """
    Return the RelatedFilesRepository bound to the current context.

    Returns:
        RelatedFilesRepository: The active repository.

    Raises:
        RuntimeError: If no RelatedFilesRepositoryManager is active in
            the current context.
    """
    repository = related_files_repo_var.get()
    if repository is not None:
        return repository
    raise RuntimeError(
        "RelatedFilesRepository not initialized in current context. "
        "Make sure to use RelatedFilesRepositoryManager."
    )

View File

@ -173,8 +173,9 @@ If the user explicitly requested implementation, that means you should first per
NEVER ANNOUNCE WHAT YOU ARE DOING, JUST DO IT!
AS THE RESEARCH AGENT, YOU MUST NOT WRITE OR MODIFY ANY FILES. IF FILE MODIFICATION OR IMPLEMENTATINO IS REQUIRED, CALL request_implementation.
AS THE RESEARCH AGENT, YOU MUST NOT WRITE OR MODIFY ANY FILES. IF FILE MODIFICATION OR IMPLEMENTATION IS REQUIRED, CALL request_implementation.
IF THE USER ASKED YOU TO UPDATE A FILE, JUST DO RESEARCH FIRST, EMIT YOUR RESEARCH NOTES, THEN CALL request_implementation.
CALL request_implementation ONLY ONCE! ONCE THE PLAN COMPLETES, YOU'RE DONE.
"""
)

View File

@ -16,6 +16,7 @@ from ra_aid.console.formatting import print_error
from ra_aid.database.repositories.human_input_repository import HumanInputRepository
from ra_aid.database.repositories.key_fact_repository import get_key_fact_repository
from ra_aid.database.repositories.key_snippet_repository import get_key_snippet_repository
from ra_aid.database.repositories.related_files_repository import get_related_files_repository
from ra_aid.database.repositories.research_note_repository import get_research_note_repository
from ra_aid.exceptions import AgentInterrupt
from ra_aid.model_formatters import format_key_facts_dict
@ -340,7 +341,7 @@ def request_task_implementation(task_spec: str) -> str:
)
# Get required parameters
related_files = list(_global_memory["related_files"].values())
related_files = list(get_related_files_repository().get_all().values())
try:
print_task_header(task_spec)

View File

@ -11,6 +11,7 @@ logger = logging.getLogger(__name__)
from ..database.repositories.key_fact_repository import get_key_fact_repository
from ..database.repositories.key_snippet_repository import get_key_snippet_repository
from ..database.repositories.related_files_repository import get_related_files_repository
from ..database.repositories.research_note_repository import get_research_note_repository
from ..llm import initialize_expert_llm
from ..model_formatters import format_key_facts_dict
@ -154,7 +155,7 @@ def ask_expert(question: str) -> str:
global expert_context
# Get all content first
file_paths = list(_global_memory["related_files"].values())
file_paths = list(get_related_files_repository().get_all().values())
related_contents = read_related_files(file_paths)
# Get key snippets directly from repository and format using the formatter
try:

View File

@ -1,12 +1,8 @@
import os
from typing import Any, Dict, List, Optional
try:
import magic
except ImportError:
magic = None
from langchain_core.tools import tool
from ra_aid.utils.file_utils import is_binary_file
from rich.console import Console
from rich.markdown import Markdown
from rich.panel import Panel
@ -44,10 +40,11 @@ console = Console()
# Import repositories using the get_* functions
from ra_aid.database.repositories.key_fact_repository import get_key_fact_repository
# Import the related files repository
from ra_aid.database.repositories.related_files_repository import get_related_files_repository
# Global memory store
_global_memory: Dict[str, Any] = {
"related_files": {}, # Dict[int, str] - ID to filepath mapping
"related_file_id_counter": 1, # Counter for generating unique file IDs
"agent_depth": 0,
"work_log": [], # List[WorkLogEntry] - Timestamped work events
}
@ -302,8 +299,8 @@ def get_related_files() -> List[str]:
Returns:
List of formatted strings in the format 'ID#X path/to/file.py'
"""
files = _global_memory["related_files"]
return [f"ID#{file_id} {filepath}" for file_id, filepath in sorted(files.items())]
repo = get_related_files_repository()
return repo.format_related_files()
@tool("emit_related_files")
@ -313,6 +310,7 @@ def emit_related_files(files: List[str]) -> str:
Args:
files: List of file paths to add
"""
repo = get_related_files_repository()
results = []
added_files = []
invalid_paths = []
@ -344,27 +342,20 @@ def emit_related_files(files: List[str]) -> str:
results.append(f"Skipped binary file: '{file}'")
continue
# Normalize the path
normalized_path = os.path.abspath(file)
# Add file to repository
file_id = repo.add_file(file)
# Check if normalized path already exists in values
existing_id = None
for fid, fpath in _global_memory["related_files"].items():
if fpath == normalized_path:
existing_id = fid
break
if file_id is not None:
# Check if it's a new file by comparing with previous results
is_new_file = True
for r in results:
if r.startswith(f"File ID #{file_id}:"):
is_new_file = False
break
if existing_id is not None:
# File exists, use existing ID
results.append(f"File ID #{existing_id}: {file}")
else:
# New file, assign new ID
file_id = _global_memory["related_file_id_counter"]
_global_memory["related_file_id_counter"] += 1
if is_new_file:
added_files.append((file_id, file)) # Keep original path for display
# Store normalized path with ID
_global_memory["related_files"][file_id] = normalized_path
added_files.append((file_id, file)) # Keep original path for display
results.append(f"File ID #{file_id}: {file}")
# Rich output - single consolidated panel for added files
@ -421,57 +412,7 @@ def log_work_event(event: str) -> str:
return f"Event logged: {event}"
def is_binary_file(filepath):
"""Check if a file is binary using magic library if available."""
# First check if file is empty
if os.path.getsize(filepath) == 0:
return False # Empty files are not binary
if magic:
try:
mime = magic.from_file(filepath, mime=True)
file_type = magic.from_file(filepath)
# If MIME type starts with 'text/', it's likely a text file
if mime.startswith("text/"):
return False
# Also consider 'application/x-python' and similar script types as text
if any(mime.startswith(prefix) for prefix in ['application/x-python', 'application/javascript']):
return False
# Check for common text file descriptors
text_indicators = ["text", "script", "xml", "json", "yaml", "markdown", "HTML"]
if any(indicator.lower() in file_type.lower() for indicator in text_indicators):
return False
# If none of the text indicators are present, assume it's binary
return True
except Exception:
return _is_binary_fallback(filepath)
else:
return _is_binary_fallback(filepath)
def _is_binary_fallback(filepath):
"""Fallback method to detect binary files without using magic."""
try:
# First check if file is empty
if os.path.getsize(filepath) == 0:
return False # Empty files are not binary
with open(filepath, "r", encoding="utf-8") as f:
chunk = f.read(1024)
# Check for null bytes which indicate binary content
if "\0" in chunk:
return True
# If we can read it as text without errors, it's probably not binary
return False
except UnicodeDecodeError:
# If we can't decode as UTF-8, it's likely binary
return True
def get_work_log() -> str:
@ -518,17 +459,18 @@ def reset_work_log() -> str:
@tool("deregister_related_files")
def deregister_related_files(file_ids: List[int]) -> str:
"""Delete multiple related files from global memory by their IDs.
"""Delete multiple related files by their IDs.
Silently skips any IDs that don't exist.
Args:
file_ids: List of file IDs to delete
"""
repo = get_related_files_repository()
results = []
for file_id in file_ids:
if file_id in _global_memory["related_files"]:
# Delete the file reference
deleted_file = _global_memory["related_files"].pop(file_id)
deleted_file = repo.remove_file(file_id)
if deleted_file:
success_msg = (
f"Successfully removed related file #{file_id}: {deleted_file}"
)

View File

@ -14,6 +14,7 @@ from ra_aid.models_params import DEFAULT_BASE_LATENCY, models_params
from ra_aid.proc.interactive import run_interactive_command
from ra_aid.text.processing import truncate_output
from ra_aid.tools.memory import _global_memory, log_work_event
from ra_aid.database.repositories.related_files_repository import get_related_files_repository
console = Console()
logger = get_logger(__name__)
@ -91,13 +92,18 @@ def run_programming_task(
# Get combined list of files (explicit + related) with normalized paths
# and deduplicated using set operations
related_files_paths = []
try:
repo = get_related_files_repository()
related_files_paths = list(repo.get_all().values())
logger.debug("Retrieved related files from repository")
except RuntimeError as e:
# Repository not initialized
logger.warning(f"Failed to get related files repository: {e}")
files_to_use = list(
{os.path.abspath(f) for f in (files or [])}
| {
os.path.abspath(f)
for f in _global_memory["related_files"].values()
if "related_files" in _global_memory
}
| {os.path.abspath(f) for f in related_files_paths}
)
# Add config file if specified

5
ra_aid/utils/__init__.py Normal file
View File

@ -0,0 +1,5 @@
"""Utility functions for the ra-aid project."""
from .file_utils import is_binary_file
__all__ = ["is_binary_file"]

View File

@ -0,0 +1,61 @@
"""Utility functions for file operations."""
import os
try:
import magic
except ImportError:
magic = None
def is_binary_file(filepath):
"""Check if a file is binary using magic library if available."""
# First check if file is empty
if os.path.getsize(filepath) == 0:
return False # Empty files are not binary
if magic:
try:
mime = magic.from_file(filepath, mime=True)
file_type = magic.from_file(filepath)
# If MIME type starts with 'text/', it's likely a text file
if mime.startswith("text/"):
return False
# Also consider 'application/x-python' and similar script types as text
if any(mime.startswith(prefix) for prefix in ['application/x-python', 'application/javascript']):
return False
# Check for common text file descriptors
text_indicators = ["text", "script", "xml", "json", "yaml", "markdown", "HTML"]
if any(indicator.lower() in file_type.lower() for indicator in text_indicators):
return False
# If none of the text indicators are present, assume it's binary
return True
except Exception:
return _is_binary_fallback(filepath)
else:
return _is_binary_fallback(filepath)
def _is_binary_fallback(filepath):
"""Fallback method to detect binary files without using magic."""
try:
# First check if file is empty
if os.path.getsize(filepath) == 0:
return False # Empty files are not binary
with open(filepath, "r", encoding="utf-8") as f:
chunk = f.read(1024)
# Check for null bytes which indicate binary content
if "\0" in chunk:
return True
# If we can read it as text without errors, it's probably not binary
return False
except UnicodeDecodeError:
# If we can't decode as UTF-8, it's likely binary
return True

View File

@ -1,6 +1,7 @@
"""Unit tests for __main__.py argument parsing."""
import pytest
from unittest.mock import patch, MagicMock
from ra_aid.__main__ import parse_arguments
from ra_aid.config import DEFAULT_RECURSION_LIMIT
@ -12,8 +13,6 @@ def mock_dependencies(monkeypatch):
"""Mock all dependencies needed for main()."""
# Initialize global memory with necessary keys to prevent KeyError
_global_memory.clear()
_global_memory["related_files"] = {}
_global_memory["related_file_id_counter"] = 1
_global_memory["agent_depth"] = 0
_global_memory["work_log"] = []
_global_memory["config"] = {}
@ -37,6 +36,28 @@ def mock_dependencies(monkeypatch):
monkeypatch.setattr("ra_aid.__main__.initialize_llm", mock_config_update)
@pytest.fixture(autouse=True)
def mock_related_files_repository():
    """Mock the RelatedFilesRepository to avoid database operations during tests"""
    with patch('ra_aid.database.repositories.related_files_repository.related_files_repo_var') as patched_var:
        stub_repo = MagicMock()
        # In-memory dict standing in for the repository's file storage.
        stored_files = {}
        stub_repo.get_all.return_value = stored_files
        stub_repo.format_related_files.return_value = [
            f"ID#{file_id} {filepath}" for file_id, filepath in sorted(stored_files.items())
        ]
        # get_related_files_repository() resolves through the contextvar,
        # so route its .get() to our stub.
        patched_var.get.return_value = stub_repo
        yield stub_repo
def test_recursion_limit_in_global_config(mock_dependencies):
"""Test that recursion limit is correctly set in global config."""
import sys
@ -122,8 +143,6 @@ def test_temperature_validation(mock_dependencies):
# Reset global memory for clean test
_global_memory.clear()
_global_memory["related_files"] = {}
_global_memory["related_file_id_counter"] = 1
_global_memory["agent_depth"] = 0
_global_memory["work_log"] = []
_global_memory["config"] = {}

View File

@ -1,10 +1,57 @@
import pytest
from unittest.mock import patch, MagicMock
from ra_aid.tools.programmer import (
get_aider_executable,
parse_aider_flags,
run_programming_task,
)
from ra_aid.database.repositories.related_files_repository import get_related_files_repository
@pytest.fixture(autouse=True)
def mock_related_files_repository():
    """Mock the RelatedFilesRepository to avoid database operations during tests.

    Yields a MagicMock whose add_file/remove_file/get_all/format_related_files
    operate on an in-memory dict, patching both the repository contextvar and
    the get_related_files_repository reference imported by the programmer tool.
    """
    # Imported locally: this module's import block does not bring in ``os``,
    # and mock_add_file needs os.path.abspath at call time (NameError otherwise).
    import os

    with patch('ra_aid.database.repositories.related_files_repository.related_files_repo_var') as mock_repo_var:
        # Setup a mock repository
        mock_repo = MagicMock()
        # In-memory stand-in for the repository's file map (ID -> path).
        related_files = {}

        mock_repo.get_all.return_value = related_files
        mock_repo.format_related_files.return_value = [
            f"ID#{file_id} {filepath}" for file_id, filepath in sorted(related_files.items())
        ]

        def mock_add_file(filepath):
            """Register a file, reusing the existing ID for a known normalized path."""
            normalized_path = os.path.abspath(filepath)
            for file_id, path in related_files.items():
                if path == normalized_path:
                    return file_id
            # New file: IDs are assigned sequentially from 1.
            file_id = len(related_files) + 1
            related_files[file_id] = normalized_path
            return file_id

        mock_repo.add_file.side_effect = mock_add_file

        def mock_remove_file(file_id):
            """Remove a file by ID, returning its path, or None when absent."""
            if file_id in related_files:
                return related_files.pop(file_id)
            return None

        mock_repo.remove_file.side_effect = mock_remove_file

        # get_related_files_repository() resolves through the contextvar.
        mock_repo_var.get.return_value = mock_repo
        # The programmer tool imported the getter directly, so patch that too.
        with patch('ra_aid.tools.programmer.get_related_files_repository', return_value=mock_repo):
            yield mock_repo
# Test cases for parse_aider_flags function
test_cases = [
@ -78,11 +125,11 @@ def test_parse_aider_flags(input_flags, expected, description):
assert result == expected, f"Failed test case: {description}"
def test_aider_config_flag(mocker):
def test_aider_config_flag(mocker, mock_related_files_repository):
"""Test that aider config flag is properly included in the command when specified."""
# Mock config in global memory but not related files (using repository now)
mock_memory = {
"config": {"aider_config": "/path/to/config.yml"},
"related_files": {},
}
mocker.patch("ra_aid.tools.programmer._global_memory", mock_memory)
@ -99,15 +146,15 @@ def test_aider_config_flag(mocker):
assert args[config_index + 1] == "/path/to/config.yml"
def test_path_normalization_and_deduplication(mocker, tmp_path):
def test_path_normalization_and_deduplication(mocker, tmp_path, mock_related_files_repository):
"""Test path normalization and deduplication in run_programming_task."""
# Create a temporary test file
test_file = tmp_path / "test.py"
test_file.write_text("")
new_file = tmp_path / "new.py"
# Mock dependencies
mocker.patch("ra_aid.tools.programmer._global_memory", {"related_files": {}})
# Mock dependencies - only need to mock config part of global memory now
mocker.patch("ra_aid.tools.programmer._global_memory", {"config": {}})
mocker.patch(
"ra_aid.tools.programmer.get_aider_executable", return_value="/path/to/aider"
)

View File

@ -2,6 +2,7 @@
import pytest
from unittest.mock import patch, MagicMock
import os
from ra_aid.tools.agent import (
request_research,
@ -10,21 +11,39 @@ from ra_aid.tools.agent import (
request_research_and_implementation,
)
from ra_aid.tools.memory import _global_memory
from ra_aid.database.repositories.related_files_repository import get_related_files_repository
@pytest.fixture
def reset_memory():
"""Reset global memory before each test"""
_global_memory["related_files"] = {}
_global_memory["related_file_id_counter"] = 0
_global_memory["work_log"] = []
yield
# Clean up after test
_global_memory["related_files"] = {}
_global_memory["related_file_id_counter"] = 0
_global_memory["work_log"] = []
@pytest.fixture(autouse=True)
def mock_related_files_repository():
    """Mock the RelatedFilesRepository to avoid database operations during tests"""
    with patch('ra_aid.database.repositories.related_files_repository.related_files_repo_var') as repo_var_patch:
        fake_repo = MagicMock()
        # Dict standing in for the repository's internal file storage.
        files_by_id = {}
        fake_repo.get_all.return_value = files_by_id
        fake_repo.format_related_files.return_value = [
            f"ID#{file_id} {filepath}" for file_id, filepath in sorted(files_by_id.items())
        ]
        # Resolve the contextvar lookup to the fake repository.
        repo_var_patch.get.return_value = fake_repo
        yield fake_repo
@pytest.fixture
def mock_functions():
"""Mock functions used in agent.py"""

View File

@ -17,11 +17,11 @@ from ra_aid.tools.memory import (
get_work_log,
log_work_event,
reset_work_log,
is_binary_file,
_is_binary_fallback,
)
from ra_aid.utils.file_utils import is_binary_file, _is_binary_fallback
from ra_aid.database.repositories.key_fact_repository import get_key_fact_repository
from ra_aid.database.repositories.key_snippet_repository import get_key_snippet_repository
from ra_aid.database.repositories.related_files_repository import get_related_files_repository
from ra_aid.database.connection import DatabaseManager
from ra_aid.database.models import KeyFact
@ -29,13 +29,9 @@ from ra_aid.database.models import KeyFact
@pytest.fixture
def reset_memory():
"""Reset global memory before each test"""
_global_memory["related_files"] = {}
_global_memory["related_file_id_counter"] = 0
_global_memory["work_log"] = []
yield
# Clean up after test
_global_memory["related_files"] = {}
_global_memory["related_file_id_counter"] = 0
_global_memory["work_log"] = []
@ -165,6 +161,66 @@ def mock_key_snippet_repository():
yield memory_mock_repo
@pytest.fixture(autouse=True)
def mock_related_files_repository():
    """Mock the RelatedFilesRepository to avoid database operations during tests"""
    # Patches the getter used by ra_aid.tools.memory; the mocked repository
    # object is mock_repo.return_value, backed by the in-memory state below.
    with patch('ra_aid.tools.memory.get_related_files_repository') as mock_repo:
        # Setup the mock repository to behave like the original, but using memory
        related_files = {}  # Local in-memory storage (file ID -> normalized path)
        id_counter = 0  # NOTE(review): real repository starts IDs at 1; tests here rely on 0-based IDs — confirm intentional

        # Mock add_file method
        def mock_add_file(filepath):
            nonlocal id_counter
            # Check if normalized path already exists in values
            # NOTE(review): duplicate check runs before the existence checks,
            # unlike the real repository — order matters to callers only for
            # previously-added paths that no longer exist on disk.
            normalized_path = os.path.abspath(filepath)
            for file_id, path in related_files.items():
                if path == normalized_path:
                    return file_id
            # First check if path exists
            if not os.path.exists(filepath):
                return None
            # Then check if it's a directory
            if os.path.isdir(filepath):
                return None
            # Validate it's a regular file
            if not os.path.isfile(filepath):
                return None
            # Check if it's a binary file (don't actually check in tests)
            # We'll mock is_binary_file separately when needed
            # Add new file
            file_id = id_counter
            id_counter += 1
            related_files[file_id] = normalized_path
            return file_id

        mock_repo.return_value.add_file.side_effect = mock_add_file

        # Mock get_all method: returns a copy, like the real repository
        def mock_get_all():
            return related_files.copy()

        mock_repo.return_value.get_all.side_effect = mock_get_all

        # Mock remove_file method: pops by ID, None when the ID is unknown
        def mock_remove_file(file_id):
            if file_id in related_files:
                return related_files.pop(file_id)
            return None

        mock_repo.return_value.remove_file.side_effect = mock_remove_file

        # Mock format_related_files method: 'ID#<id> <path>' lines sorted by ID
        def mock_format_related_files():
            return [f"ID#{file_id} {filepath}" for file_id, filepath in sorted(related_files.items())]

        mock_repo.return_value.format_related_files.side_effect = mock_format_related_files

        yield mock_repo
def test_emit_key_facts_single_fact(reset_memory, mock_repository):
"""Test emitting a single key fact using emit_key_facts"""
# Test with single fact
@ -177,9 +233,6 @@ def test_emit_key_facts_single_fact(reset_memory, mock_repository):
def test_get_memory_value_other_types(reset_memory):
"""Test get_memory_value remains compatible with other memory types"""
# Test with empty list
assert get_memory_value("plans") == ""
# Test with non-existent key
assert get_memory_value("nonexistent") == ""
@ -263,7 +316,7 @@ def test_emit_key_facts_triggers_cleaner(reset_memory, mock_repository):
"""Test that emit_key_facts triggers the cleaner agent when there are more than 30 facts"""
# Setup mock repository to return more than 30 facts
facts = []
for i in range(31):
for i in range(51):
facts.append(MagicMock(id=i, content=f"Test fact {i}", human_input_id=None))
# Mock the get_all method to return more than 30 facts
@ -407,7 +460,7 @@ def test_delete_key_snippets_empty(mock_log_work_event, reset_memory, mock_key_s
mock_key_snippet_repository.return_value.delete.assert_not_called()
def test_emit_related_files_basic(reset_memory, tmp_path):
def test_emit_related_files_basic(reset_memory, mock_related_files_repository, tmp_path):
"""Test basic adding of files with ID tracking"""
# Create test files
test_file = tmp_path / "test.py"
@ -420,23 +473,26 @@ def test_emit_related_files_basic(reset_memory, tmp_path):
# Test adding single file
result = emit_related_files.invoke({"files": [str(test_file)]})
assert result == "Files noted."
assert _global_memory["related_files"][0] == str(test_file)
# Verify file was added using the repository
mock_related_files_repository.return_value.add_file.assert_called_with(str(test_file))
# Test adding multiple files
result = emit_related_files.invoke({"files": [str(main_file), str(utils_file)]})
assert result == "Files noted."
# Verify both files exist in related_files
values = list(_global_memory["related_files"].values())
assert str(main_file) in values
assert str(utils_file) in values
# Verify both files were added
mock_related_files_repository.return_value.add_file.assert_any_call(str(main_file))
mock_related_files_repository.return_value.add_file.assert_any_call(str(utils_file))
def test_get_related_files_empty(reset_memory):
def test_get_related_files_empty(reset_memory, mock_related_files_repository):
"""Test getting related files when none added"""
# Mock empty format_related_files result
mock_related_files_repository.return_value.format_related_files.return_value = []
assert get_related_files() == []
mock_related_files_repository.return_value.format_related_files.assert_called_once()
def test_emit_related_files_duplicates(reset_memory, tmp_path):
def test_emit_related_files_duplicates(reset_memory, mock_related_files_repository, tmp_path):
"""Test that duplicate files return existing IDs with proper formatting"""
# Create test files
test_file = tmp_path / "test.py"
@ -446,46 +502,34 @@ def test_emit_related_files_duplicates(reset_memory, tmp_path):
new_file = tmp_path / "new.py"
new_file.write_text("# New file")
# Mock add_file to return consistent IDs
def mock_add_file(filepath):
if "test.py" in filepath:
return 0
elif "main.py" in filepath:
return 1
elif "new.py" in filepath:
return 2
return None
mock_related_files_repository.return_value.add_file.side_effect = mock_add_file
# Add initial files
result1 = emit_related_files.invoke({"files": [str(test_file), str(main_file)]})
assert result1 == "Files noted."
_first_id = 0 # ID of test.py
# Try adding duplicates
result2 = emit_related_files.invoke({"files": [str(test_file)]})
assert result2 == "Files noted."
assert len(_global_memory["related_files"]) == 2 # Count should not increase
# Try mix of new and duplicate files
result = emit_related_files.invoke({"files": [str(test_file), str(new_file)]})
assert result == "Files noted."
assert len(_global_memory["related_files"]) == 3
# Verify calls to add_file - should be called for each file (even duplicates)
assert mock_related_files_repository.return_value.add_file.call_count == 5
def test_related_files_id_tracking(reset_memory, tmp_path):
"""Test ID assignment and counter functionality for related files"""
# Create test files
file1 = tmp_path / "file1.py"
file1.write_text("# File 1")
file2 = tmp_path / "file2.py"
file2.write_text("# File 2")
# Add first file
result = emit_related_files.invoke({"files": [str(file1)]})
assert result == "Files noted."
assert _global_memory["related_file_id_counter"] == 1
# Add second file
result = emit_related_files.invoke({"files": [str(file2)]})
assert result == "Files noted."
assert _global_memory["related_file_id_counter"] == 2
# Verify all files stored correctly
assert _global_memory["related_files"][0] == str(file1)
assert _global_memory["related_files"][1] == str(file2)
def test_deregister_related_files(reset_memory, mock_related_files_repository, tmp_path):
    """Test deleting related files via the repository-backed tool.

    NOTE(review): this span contained diff residue — a stale pre-commit `def`
    line and a stray `@ -...,... @@` hunk header (a syntax error). Reconstructed
    here as the post-commit, repo-pattern variant: assertions target the mocked
    RelatedFilesRepository rather than the legacy `_global_memory` dict.
    """
    # Create test files
    file1 = tmp_path / "file1.py"
    file1.write_text("# File 1")
    file2 = tmp_path / "file2.py"
    file2.write_text("# File 2")
    file3 = tmp_path / "file3.py"
    file3.write_text("# File 3")
    # Add test files
    emit_related_files.invoke({"files": [str(file1), str(file2), str(file3)]})

    # Mock remove_file to return file paths for existing IDs, None otherwise.
    def mock_remove_file(file_id):
        if file_id == 0:
            return str(file1)
        elif file_id == 1:
            return str(file2)
        elif file_id == 2:
            return str(file3)
        return None

    mock_related_files_repository.return_value.remove_file.side_effect = mock_remove_file

    # Delete middle file
    result = deregister_related_files.invoke({"file_ids": [1]})
    assert result == "Files noted."
    mock_related_files_repository.return_value.remove_file.assert_called_with(1)

    # Delete multiple files including a non-existent ID; unknown IDs are
    # still forwarded to the repository, which answers None for them.
    result = deregister_related_files.invoke({"file_ids": [0, 2, 999]})
    assert result == "Files noted."
    mock_related_files_repository.return_value.remove_file.assert_any_call(0)
    mock_related_files_repository.return_value.remove_file.assert_any_call(2)
    mock_related_files_repository.return_value.remove_file.assert_any_call(999)
def test_related_files_duplicates(reset_memory, tmp_path):
    """Adding the same file twice must not create a second entry or new ID."""
    # Create test file
    target = tmp_path / "test.py"
    target.write_text("# Test file")

    # First emission registers the file.
    first = emit_related_files.invoke({"files": [str(target)]})
    assert first == "Files noted."

    # Second emission of the identical path is accepted but deduplicated.
    second = emit_related_files.invoke({"files": [str(target)]})
    assert second == "Files noted."

    # Exactly one tracked entry, and the ID counter advanced only once.
    assert len(_global_memory["related_files"]) == 1
    assert _global_memory["related_file_id_counter"] == 1
def test_emit_related_files_with_directory(reset_memory, tmp_path):
    """Directories and missing paths are rejected while regular files are kept."""
    # One directory, one real file, one path that does not exist.
    directory = tmp_path / "test_dir"
    directory.mkdir()
    regular = tmp_path / "test_file.txt"
    regular.write_text("test content")
    missing = tmp_path / "does_not_exist.txt"

    # Emit all three in one call.
    outcome = emit_related_files.invoke(
        {"files": [str(directory), str(missing), str(regular)]}
    )
    # The standard acknowledgement is returned regardless.
    assert outcome == "Files noted."

    # Only the regular file made it into memory.
    assert len(_global_memory["related_files"]) == 1
    stored = list(_global_memory["related_files"].values())
    assert str(regular) in stored
    assert str(directory) not in stored
    assert str(missing) not in stored
def test_related_files_formatting(reset_memory, tmp_path):
    """get_memory_value('related_files') yields the IDs, one per line."""
    # Register two files so IDs 0 and 1 exist.
    first = tmp_path / "file1.py"
    first.write_text("# File 1")
    second = tmp_path / "file2.py"
    second.write_text("# File 2")
    emit_related_files.invoke({"files": [str(first), str(second)]})

    # Formatted output is just the IDs joined by newlines.
    assert get_memory_value("related_files") == "0\n1"

    # An empty mapping formats to the empty string.
    _global_memory["related_files"] = {}
    assert get_memory_value("related_files") == ""
def test_emit_related_files_path_normalization(reset_memory, mock_related_files_repository, tmp_path):
    """Test that emit_related_files path normalization works correctly.

    NOTE(review): this span contained diff residue — a stale pre-commit `def`
    line plus interleaved old `_global_memory` assertions and new mock
    assertions. Reconstructed as the post-commit, repo-pattern variant; the
    legacy `_global_memory` checks were dropped as pre-commit leftovers.
    """
    # Create a test file
    test_file = tmp_path / "file.txt"
    test_file.write_text("test content")

    # Change to the temp directory so relative paths work
    import os
    original_dir = os.getcwd()
    os.chdir(tmp_path)
    try:
        # Set up mock to simulate repository-side path normalization:
        # any spelling of "file.txt" resolves to the same entry (ID 0).
        def mock_add_file(filepath):
            normalized_path = os.path.abspath(filepath)
            if normalized_path == os.path.abspath("file.txt"):
                return 0
            return None

        mock_related_files_repository.return_value.add_file.side_effect = mock_add_file

        # Add file with relative path
        result1 = emit_related_files.invoke({"files": ["file.txt"]})
        assert result1 == "Files noted."

        # Add same file with a different relative spelling — should normalize
        # to the same entry.
        result2 = emit_related_files.invoke({"files": ["./file.txt"]})
        assert result2 == "Files noted."

        # Both spellings were forwarded to the repository.
        assert mock_related_files_repository.return_value.add_file.call_count == 2
    finally:
        # Restore original directory
        os.chdir(original_dir)
@patch('ra_aid.agents.key_snippets_gc_agent.log_work_event')
def test_key_snippets_integration(mock_log_work_event, reset_memory, mock_key_snippet_repository):
    """End-to-end exercise of snippet add/delete flows against the mocked repository."""
    import os
    import tempfile

    with tempfile.TemporaryDirectory() as workdir:
        # (filename, source text, line number, description) for the initial snippets.
        specs = [
            ("file1.py", "def func1():\n pass", 10, "First function"),
            ("file2.py", "def func2():\n return True", 20, "Second function"),
            ("file3.py", "class TestClass:\n pass", 30, "Test class"),
        ]

        # Write each file to disk and build its snippet-info payload.
        snippet_infos = []
        for name, source, line, description in specs:
            path = os.path.join(workdir, name)
            with open(path, 'w') as handle:
                handle.write(source)
            snippet_infos.append({
                "filepath": path,
                "line_number": line,
                "snippet": source,
                "description": description,
            })

        # Store each snippet and confirm sequential ID assignment.
        for index, info in enumerate(snippet_infos):
            assert emit_key_snippet.invoke({"snippet_info": info}) == f"Snippet #{index} stored."

        # Clear call history before the deletion round.
        mock_key_snippet_repository.reset_mock()

        # Drop snippets 0 and 2, leaving snippet 1 behind.
        with patch('ra_aid.agents.key_snippets_gc_agent.get_key_snippet_repository', mock_key_snippet_repository):
            assert delete_key_snippets.invoke({"snippet_ids": [0, 2]}) == "Snippets deleted."

        mock_key_snippet_repository.reset_mock()

        # Add a fourth snippet; it should receive ID 3.
        file4 = os.path.join(workdir, "file4.py")
        with open(file4, 'w') as handle:
            handle.write("def func4():\n return False")
        fourth = {
            "filepath": file4,
            "line_number": 40,
            "snippet": "def func4():\n return False",
            "description": "Fourth function",
        }
        assert emit_key_snippet.invoke({"snippet_info": fourth}) == "Snippet #3 stored."

        # The repository must have been asked to persist exactly these fields.
        mock_key_snippet_repository.return_value.create.assert_called_with(
            filepath=file4,
            line_number=40,
            snippet="def func4():\n return False",
            description="Fourth function",
            human_input_id=ANY
        )

        mock_key_snippet_repository.reset_mock()

        # Remove the remaining snippets.
        with patch('ra_aid.agents.key_snippets_gc_agent.get_key_snippet_repository', mock_key_snippet_repository):
            assert delete_key_snippets.invoke({"snippet_ids": [1, 3]}) == "Snippets deleted."
@patch('ra_aid.tools.memory.is_binary_file')
def test_emit_related_files_binary_filtering(mock_is_binary, reset_memory, mock_related_files_repository, tmp_path):
    """Test that binary files are filtered out when adding related files.

    NOTE(review): this span contained diff residue — the pre-commit
    monkeypatch/tempfile variant and the post-commit @patch/tmp_path variant
    were interleaved line-by-line. Reconstructed as the post-commit,
    repo-pattern variant.
    """
    # Create test text files
    text_file1 = tmp_path / "text1.txt"
    text_file1.write_text("Text file 1 content")
    text_file2 = tmp_path / "text2.txt"
    text_file2.write_text("Text file 2 content")

    # Create test "binary" files (content is irrelevant; detection is mocked)
    binary_file1 = tmp_path / "binary1.bin"
    binary_file1.write_text("Binary file 1 content")
    binary_file2 = tmp_path / "binary2.bin"
    binary_file2.write_text("Binary file 2 content")

    # Mock is_binary_file to flag anything with ".bin" in the path as binary
    def mock_binary_check(filepath):
        return ".bin" in str(filepath)

    mock_is_binary.side_effect = mock_binary_check

    # Call emit_related_files with mix of text and binary files
    result = emit_related_files.invoke({
        "files": [
            str(text_file1),
            str(binary_file1),
            str(text_file2),
            str(binary_file2),
        ]
    })

    # Verify the result message mentions the skipped binary files
    assert "Files noted." in result
    assert "Binary files skipped:" in result
    assert str(binary_file1) in result
    assert str(binary_file2) in result

    # Verify repository calls - add_file is reached only for text files;
    # binary files are filtered out before hitting the repository.
    assert mock_related_files_repository.return_value.add_file.call_count == 2
    mock_related_files_repository.return_value.add_file.assert_any_call(str(text_file1))
    mock_related_files_repository.return_value.add_file.assert_any_call(str(text_file2))
def test_is_binary_file_with_ascii():
    """Test that ASCII files are correctly identified as text files.

    NOTE(review): this span contained diff residue — a stale pre-commit `def`
    line, the old magic/monkeypatch code path, and a stray hunk header.
    Reconstructed as the post-commit variant that exercises both
    is_binary_file and the _is_binary_fallback helper directly.
    """
    import os
    import tempfile

    # Create a test ASCII file. NOTE(review): the original payload was hidden
    # by a diff hunk boundary — any plain-ASCII text exercises the same path.
    with tempfile.NamedTemporaryFile(mode='w', delete=False) as f:
        f.write("This is a plain ASCII text file.")
        ascii_file_path = f.name
    try:
        # Primary implementation must classify ASCII content as text.
        is_binary = is_binary_file(ascii_file_path)
        assert not is_binary, "ASCII file should not be identified as binary"

        # The magic-free fallback implementation must agree.
        is_binary_fallback = _is_binary_fallback(ascii_file_path)
        assert not is_binary_fallback, "ASCII file should not be identified as binary with fallback method"
    finally:
        # Clean up
        if os.path.exists(ascii_file_path):
            os.unlink(ascii_file_path)
def test_is_binary_file_with_null_bytes():
    """Test that files with null bytes are correctly identified as binary.

    NOTE(review): this span contained diff residue — a stale pre-commit `def`
    line, the old magic/monkeypatch code path, and a stray hunk header.
    Reconstructed as the post-commit variant that exercises both
    is_binary_file and the _is_binary_fallback helper directly.
    """
    import os
    import tempfile

    # Create a file with null bytes (binary content). NOTE(review): the
    # original payload was hidden by a diff hunk boundary — any byte string
    # containing \x00 exercises the same path.
    binary_file = tempfile.NamedTemporaryFile(delete=False)
    binary_file.write(b"some text\x00with a null byte")
    binary_file.close()
    try:
        # Primary implementation must classify null-byte content as binary.
        is_binary = is_binary_file(binary_file.name)
        assert is_binary, "File with null bytes should be identified as binary"

        # The magic-free fallback implementation must agree.
        is_binary_fallback = _is_binary_fallback(binary_file.name)
        assert is_binary_fallback, "File with null bytes should be identified as binary with fallback method"
    finally:
        # Clean up
        if os.path.exists(binary_file.name):
            os.unlink(binary_file.name)
@ -850,7 +708,7 @@ def test_python_file_detection():
import magic
if magic:
# Only run this part of the test if magic is available
with patch('ra_aid.tools.memory.magic') as mock_magic:
with patch('ra_aid.utils.file_utils.magic') as mock_magic:
# Mock magic to simulate the behavior that causes the issue
mock_magic.from_file.side_effect = [
"text/x-python", # First call with mime=True