use repo pattern
This commit is contained in:
parent
0afed55809
commit
600bf355d9
|
|
@ -52,6 +52,9 @@ from ra_aid.database.repositories.human_input_repository import (
|
||||||
from ra_aid.database.repositories.research_note_repository import (
|
from ra_aid.database.repositories.research_note_repository import (
|
||||||
ResearchNoteRepositoryManager, get_research_note_repository
|
ResearchNoteRepositoryManager, get_research_note_repository
|
||||||
)
|
)
|
||||||
|
from ra_aid.database.repositories.related_files_repository import (
|
||||||
|
RelatedFilesRepositoryManager
|
||||||
|
)
|
||||||
from ra_aid.model_formatters import format_key_facts_dict
|
from ra_aid.model_formatters import format_key_facts_dict
|
||||||
from ra_aid.model_formatters.key_snippets_formatter import format_key_snippets_dict
|
from ra_aid.model_formatters.key_snippets_formatter import format_key_snippets_dict
|
||||||
from ra_aid.console.output import cpm
|
from ra_aid.console.output import cpm
|
||||||
|
|
@ -402,12 +405,14 @@ def main():
|
||||||
with KeyFactRepositoryManager(db) as key_fact_repo, \
|
with KeyFactRepositoryManager(db) as key_fact_repo, \
|
||||||
KeySnippetRepositoryManager(db) as key_snippet_repo, \
|
KeySnippetRepositoryManager(db) as key_snippet_repo, \
|
||||||
HumanInputRepositoryManager(db) as human_input_repo, \
|
HumanInputRepositoryManager(db) as human_input_repo, \
|
||||||
ResearchNoteRepositoryManager(db) as research_note_repo:
|
ResearchNoteRepositoryManager(db) as research_note_repo, \
|
||||||
|
RelatedFilesRepositoryManager() as related_files_repo:
|
||||||
# This initializes all repositories and makes them available via their respective get methods
|
# This initializes all repositories and makes them available via their respective get methods
|
||||||
logger.debug("Initialized KeyFactRepository")
|
logger.debug("Initialized KeyFactRepository")
|
||||||
logger.debug("Initialized KeySnippetRepository")
|
logger.debug("Initialized KeySnippetRepository")
|
||||||
logger.debug("Initialized HumanInputRepository")
|
logger.debug("Initialized HumanInputRepository")
|
||||||
logger.debug("Initialized ResearchNoteRepository")
|
logger.debug("Initialized ResearchNoteRepository")
|
||||||
|
logger.debug("Initialized RelatedFilesRepository")
|
||||||
|
|
||||||
# Check dependencies before proceeding
|
# Check dependencies before proceeding
|
||||||
check_dependencies()
|
check_dependencies()
|
||||||
|
|
|
||||||
|
|
@ -397,7 +397,7 @@ def run_research_agent(
|
||||||
logger.error(f"Failed to access key fact repository: {str(e)}")
|
logger.error(f"Failed to access key fact repository: {str(e)}")
|
||||||
key_facts = ""
|
key_facts = ""
|
||||||
code_snippets = _global_memory.get("code_snippets", "")
|
code_snippets = _global_memory.get("code_snippets", "")
|
||||||
related_files = _global_memory.get("related_files", "")
|
related_files = get_related_files()
|
||||||
|
|
||||||
current_date = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
current_date = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||||
working_directory = os.getcwd()
|
working_directory = os.getcwd()
|
||||||
|
|
@ -552,7 +552,7 @@ def run_web_research_agent(
|
||||||
logger.error(f"Failed to access key fact repository: {str(e)}")
|
logger.error(f"Failed to access key fact repository: {str(e)}")
|
||||||
key_facts = ""
|
key_facts = ""
|
||||||
code_snippets = _global_memory.get("code_snippets", "")
|
code_snippets = _global_memory.get("code_snippets", "")
|
||||||
related_files = _global_memory.get("related_files", "")
|
related_files = get_related_files()
|
||||||
|
|
||||||
current_date = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
current_date = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||||
working_directory = os.getcwd()
|
working_directory = os.getcwd()
|
||||||
|
|
|
||||||
|
|
@ -20,6 +20,11 @@ from ra_aid.database.repositories.key_snippet_repository import (
|
||||||
KeySnippetRepositoryManager,
|
KeySnippetRepositoryManager,
|
||||||
get_key_snippet_repository
|
get_key_snippet_repository
|
||||||
)
|
)
|
||||||
|
from ra_aid.database.repositories.related_files_repository import (
|
||||||
|
RelatedFilesRepository,
|
||||||
|
RelatedFilesRepositoryManager,
|
||||||
|
get_related_files_repository
|
||||||
|
)
|
||||||
from ra_aid.database.repositories.research_note_repository import (
|
from ra_aid.database.repositories.research_note_repository import (
|
||||||
ResearchNoteRepository,
|
ResearchNoteRepository,
|
||||||
ResearchNoteRepositoryManager,
|
ResearchNoteRepositoryManager,
|
||||||
|
|
@ -36,6 +41,9 @@ __all__ = [
|
||||||
'KeySnippetRepository',
|
'KeySnippetRepository',
|
||||||
'KeySnippetRepositoryManager',
|
'KeySnippetRepositoryManager',
|
||||||
'get_key_snippet_repository',
|
'get_key_snippet_repository',
|
||||||
|
'RelatedFilesRepository',
|
||||||
|
'RelatedFilesRepositoryManager',
|
||||||
|
'get_related_files_repository',
|
||||||
'ResearchNoteRepository',
|
'ResearchNoteRepository',
|
||||||
'ResearchNoteRepositoryManager',
|
'ResearchNoteRepositoryManager',
|
||||||
'get_research_note_repository',
|
'get_research_note_repository',
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,169 @@
|
||||||
|
import contextvars
|
||||||
|
import os
|
||||||
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
|
# Import is_binary_file from memory.py
|
||||||
|
from ra_aid.utils.file_utils import is_binary_file
|
||||||
|
|
||||||
|
# Create contextvar to hold the RelatedFilesRepository instance
|
||||||
|
related_files_repo_var = contextvars.ContextVar("related_files_repo", default=None)
|
||||||
|
|
||||||
|
|
||||||
|
class RelatedFilesRepository:
|
||||||
|
"""
|
||||||
|
Repository for managing related files in memory.
|
||||||
|
|
||||||
|
This class provides methods to add, remove, and retrieve related files.
|
||||||
|
It does not require database models and operates entirely in memory.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
"""
|
||||||
|
Initialize the RelatedFilesRepository.
|
||||||
|
"""
|
||||||
|
self._related_files: Dict[int, str] = {}
|
||||||
|
self._id_counter: int = 1
|
||||||
|
|
||||||
|
def get_all(self) -> Dict[int, str]:
|
||||||
|
"""
|
||||||
|
Get all related files.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict[int, str]: Dictionary mapping file IDs to file paths
|
||||||
|
"""
|
||||||
|
return self._related_files.copy()
|
||||||
|
|
||||||
|
def add_file(self, filepath: str) -> Optional[int]:
|
||||||
|
"""
|
||||||
|
Add a file to the repository.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
filepath: Path to the file to add
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Optional[int]: The ID assigned to the file, or None if the file could not be added
|
||||||
|
"""
|
||||||
|
# First check if path exists
|
||||||
|
if not os.path.exists(filepath):
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Then check if it's a directory
|
||||||
|
if os.path.isdir(filepath):
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Validate it's a regular file
|
||||||
|
if not os.path.isfile(filepath):
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Check if it's a binary file
|
||||||
|
if is_binary_file(filepath):
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Normalize the path
|
||||||
|
normalized_path = os.path.abspath(filepath)
|
||||||
|
|
||||||
|
# Check if normalized path already exists in values
|
||||||
|
for file_id, path in self._related_files.items():
|
||||||
|
if path == normalized_path:
|
||||||
|
return file_id
|
||||||
|
|
||||||
|
# Add new file
|
||||||
|
file_id = self._id_counter
|
||||||
|
self._id_counter += 1
|
||||||
|
self._related_files[file_id] = normalized_path
|
||||||
|
|
||||||
|
return file_id
|
||||||
|
|
||||||
|
def remove_file(self, file_id: int) -> Optional[str]:
|
||||||
|
"""
|
||||||
|
Remove a file from the repository.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_id: ID of the file to remove
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Optional[str]: The path of the removed file, or None if the file ID was not found
|
||||||
|
"""
|
||||||
|
if file_id in self._related_files:
|
||||||
|
return self._related_files.pop(file_id)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def format_related_files(self) -> List[str]:
|
||||||
|
"""
|
||||||
|
Format related files as 'ID#X path/to/file'.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[str]: Formatted strings for each related file
|
||||||
|
"""
|
||||||
|
return [f"ID#{file_id} {filepath}" for file_id, filepath in sorted(self._related_files.items())]
|
||||||
|
|
||||||
|
|
||||||
|
class RelatedFilesRepositoryManager:
|
||||||
|
"""
|
||||||
|
Context manager for RelatedFilesRepository.
|
||||||
|
|
||||||
|
This class provides a context manager interface for RelatedFilesRepository,
|
||||||
|
using the contextvars approach for thread safety.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
with RelatedFilesRepositoryManager() as repo:
|
||||||
|
# Use the repository
|
||||||
|
file_id = repo.add_file("path/to/file.py")
|
||||||
|
all_files = repo.get_all()
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
"""
|
||||||
|
Initialize the RelatedFilesRepositoryManager.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __enter__(self) -> 'RelatedFilesRepository':
|
||||||
|
"""
|
||||||
|
Initialize the RelatedFilesRepository and return it.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
RelatedFilesRepository: The initialized repository
|
||||||
|
"""
|
||||||
|
repo = RelatedFilesRepository()
|
||||||
|
related_files_repo_var.set(repo)
|
||||||
|
return repo
|
||||||
|
|
||||||
|
def __exit__(
|
||||||
|
self,
|
||||||
|
exc_type: Optional[type],
|
||||||
|
exc_val: Optional[Exception],
|
||||||
|
exc_tb: Optional[object],
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Reset the repository when exiting the context.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
exc_type: The exception type if an exception was raised
|
||||||
|
exc_val: The exception value if an exception was raised
|
||||||
|
exc_tb: The traceback if an exception was raised
|
||||||
|
"""
|
||||||
|
# Reset the contextvar to None
|
||||||
|
related_files_repo_var.set(None)
|
||||||
|
|
||||||
|
# Don't suppress exceptions
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def get_related_files_repository() -> RelatedFilesRepository:
|
||||||
|
"""
|
||||||
|
Get the current RelatedFilesRepository instance.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
RelatedFilesRepository: The current repository instance
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError: If no repository is set in the current context
|
||||||
|
"""
|
||||||
|
repo = related_files_repo_var.get()
|
||||||
|
if repo is None:
|
||||||
|
raise RuntimeError(
|
||||||
|
"RelatedFilesRepository not initialized in current context. "
|
||||||
|
"Make sure to use RelatedFilesRepositoryManager."
|
||||||
|
)
|
||||||
|
return repo
|
||||||
|
|
@ -173,8 +173,9 @@ If the user explicitly requested implementation, that means you should first per
|
||||||
|
|
||||||
NEVER ANNOUNCE WHAT YOU ARE DOING, JUST DO IT!
|
NEVER ANNOUNCE WHAT YOU ARE DOING, JUST DO IT!
|
||||||
|
|
||||||
AS THE RESEARCH AGENT, YOU MUST NOT WRITE OR MODIFY ANY FILES. IF FILE MODIFICATION OR IMPLEMENTATINO IS REQUIRED, CALL request_implementation.
|
AS THE RESEARCH AGENT, YOU MUST NOT WRITE OR MODIFY ANY FILES. IF FILE MODIFICATION OR IMPLEMENTATION IS REQUIRED, CALL request_implementation.
|
||||||
IF THE USER ASKED YOU TO UPDATE A FILE, JUST DO RESEARCH FIRST, EMIT YOUR RESEARCH NOTES, THEN CALL request_implementation.
|
IF THE USER ASKED YOU TO UPDATE A FILE, JUST DO RESEARCH FIRST, EMIT YOUR RESEARCH NOTES, THEN CALL request_implementation.
|
||||||
|
CALL request_implementation ONLY ONCE! ONCE THE PLAN COMPLETES, YOU'RE DONE.
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -16,6 +16,7 @@ from ra_aid.console.formatting import print_error
|
||||||
from ra_aid.database.repositories.human_input_repository import HumanInputRepository
|
from ra_aid.database.repositories.human_input_repository import HumanInputRepository
|
||||||
from ra_aid.database.repositories.key_fact_repository import get_key_fact_repository
|
from ra_aid.database.repositories.key_fact_repository import get_key_fact_repository
|
||||||
from ra_aid.database.repositories.key_snippet_repository import get_key_snippet_repository
|
from ra_aid.database.repositories.key_snippet_repository import get_key_snippet_repository
|
||||||
|
from ra_aid.database.repositories.related_files_repository import get_related_files_repository
|
||||||
from ra_aid.database.repositories.research_note_repository import get_research_note_repository
|
from ra_aid.database.repositories.research_note_repository import get_research_note_repository
|
||||||
from ra_aid.exceptions import AgentInterrupt
|
from ra_aid.exceptions import AgentInterrupt
|
||||||
from ra_aid.model_formatters import format_key_facts_dict
|
from ra_aid.model_formatters import format_key_facts_dict
|
||||||
|
|
@ -340,7 +341,7 @@ def request_task_implementation(task_spec: str) -> str:
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get required parameters
|
# Get required parameters
|
||||||
related_files = list(_global_memory["related_files"].values())
|
related_files = list(get_related_files_repository().get_all().values())
|
||||||
|
|
||||||
try:
|
try:
|
||||||
print_task_header(task_spec)
|
print_task_header(task_spec)
|
||||||
|
|
|
||||||
|
|
@ -11,6 +11,7 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
from ..database.repositories.key_fact_repository import get_key_fact_repository
|
from ..database.repositories.key_fact_repository import get_key_fact_repository
|
||||||
from ..database.repositories.key_snippet_repository import get_key_snippet_repository
|
from ..database.repositories.key_snippet_repository import get_key_snippet_repository
|
||||||
|
from ..database.repositories.related_files_repository import get_related_files_repository
|
||||||
from ..database.repositories.research_note_repository import get_research_note_repository
|
from ..database.repositories.research_note_repository import get_research_note_repository
|
||||||
from ..llm import initialize_expert_llm
|
from ..llm import initialize_expert_llm
|
||||||
from ..model_formatters import format_key_facts_dict
|
from ..model_formatters import format_key_facts_dict
|
||||||
|
|
@ -154,7 +155,7 @@ def ask_expert(question: str) -> str:
|
||||||
global expert_context
|
global expert_context
|
||||||
|
|
||||||
# Get all content first
|
# Get all content first
|
||||||
file_paths = list(_global_memory["related_files"].values())
|
file_paths = list(get_related_files_repository().get_all().values())
|
||||||
related_contents = read_related_files(file_paths)
|
related_contents = read_related_files(file_paths)
|
||||||
# Get key snippets directly from repository and format using the formatter
|
# Get key snippets directly from repository and format using the formatter
|
||||||
try:
|
try:
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,8 @@
|
||||||
import os
|
import os
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
try:
|
|
||||||
import magic
|
|
||||||
except ImportError:
|
|
||||||
magic = None
|
|
||||||
|
|
||||||
from langchain_core.tools import tool
|
from langchain_core.tools import tool
|
||||||
|
from ra_aid.utils.file_utils import is_binary_file
|
||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
from rich.markdown import Markdown
|
from rich.markdown import Markdown
|
||||||
from rich.panel import Panel
|
from rich.panel import Panel
|
||||||
|
|
@ -44,10 +40,11 @@ console = Console()
|
||||||
# Import repositories using the get_* functions
|
# Import repositories using the get_* functions
|
||||||
from ra_aid.database.repositories.key_fact_repository import get_key_fact_repository
|
from ra_aid.database.repositories.key_fact_repository import get_key_fact_repository
|
||||||
|
|
||||||
|
# Import the related files repository
|
||||||
|
from ra_aid.database.repositories.related_files_repository import get_related_files_repository
|
||||||
|
|
||||||
# Global memory store
|
# Global memory store
|
||||||
_global_memory: Dict[str, Any] = {
|
_global_memory: Dict[str, Any] = {
|
||||||
"related_files": {}, # Dict[int, str] - ID to filepath mapping
|
|
||||||
"related_file_id_counter": 1, # Counter for generating unique file IDs
|
|
||||||
"agent_depth": 0,
|
"agent_depth": 0,
|
||||||
"work_log": [], # List[WorkLogEntry] - Timestamped work events
|
"work_log": [], # List[WorkLogEntry] - Timestamped work events
|
||||||
}
|
}
|
||||||
|
|
@ -302,8 +299,8 @@ def get_related_files() -> List[str]:
|
||||||
Returns:
|
Returns:
|
||||||
List of formatted strings in the format 'ID#X path/to/file.py'
|
List of formatted strings in the format 'ID#X path/to/file.py'
|
||||||
"""
|
"""
|
||||||
files = _global_memory["related_files"]
|
repo = get_related_files_repository()
|
||||||
return [f"ID#{file_id} {filepath}" for file_id, filepath in sorted(files.items())]
|
return repo.format_related_files()
|
||||||
|
|
||||||
|
|
||||||
@tool("emit_related_files")
|
@tool("emit_related_files")
|
||||||
|
|
@ -313,6 +310,7 @@ def emit_related_files(files: List[str]) -> str:
|
||||||
Args:
|
Args:
|
||||||
files: List of file paths to add
|
files: List of file paths to add
|
||||||
"""
|
"""
|
||||||
|
repo = get_related_files_repository()
|
||||||
results = []
|
results = []
|
||||||
added_files = []
|
added_files = []
|
||||||
invalid_paths = []
|
invalid_paths = []
|
||||||
|
|
@ -344,27 +342,20 @@ def emit_related_files(files: List[str]) -> str:
|
||||||
results.append(f"Skipped binary file: '{file}'")
|
results.append(f"Skipped binary file: '{file}'")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Normalize the path
|
# Add file to repository
|
||||||
normalized_path = os.path.abspath(file)
|
file_id = repo.add_file(file)
|
||||||
|
|
||||||
# Check if normalized path already exists in values
|
if file_id is not None:
|
||||||
existing_id = None
|
# Check if it's a new file by comparing with previous results
|
||||||
for fid, fpath in _global_memory["related_files"].items():
|
is_new_file = True
|
||||||
if fpath == normalized_path:
|
for r in results:
|
||||||
existing_id = fid
|
if r.startswith(f"File ID #{file_id}:"):
|
||||||
break
|
is_new_file = False
|
||||||
|
break
|
||||||
|
|
||||||
if existing_id is not None:
|
if is_new_file:
|
||||||
# File exists, use existing ID
|
added_files.append((file_id, file)) # Keep original path for display
|
||||||
results.append(f"File ID #{existing_id}: {file}")
|
|
||||||
else:
|
|
||||||
# New file, assign new ID
|
|
||||||
file_id = _global_memory["related_file_id_counter"]
|
|
||||||
_global_memory["related_file_id_counter"] += 1
|
|
||||||
|
|
||||||
# Store normalized path with ID
|
|
||||||
_global_memory["related_files"][file_id] = normalized_path
|
|
||||||
added_files.append((file_id, file)) # Keep original path for display
|
|
||||||
results.append(f"File ID #{file_id}: {file}")
|
results.append(f"File ID #{file_id}: {file}")
|
||||||
|
|
||||||
# Rich output - single consolidated panel for added files
|
# Rich output - single consolidated panel for added files
|
||||||
|
|
@ -421,57 +412,7 @@ def log_work_event(event: str) -> str:
|
||||||
return f"Event logged: {event}"
|
return f"Event logged: {event}"
|
||||||
|
|
||||||
|
|
||||||
def is_binary_file(filepath):
|
|
||||||
"""Check if a file is binary using magic library if available."""
|
|
||||||
# First check if file is empty
|
|
||||||
if os.path.getsize(filepath) == 0:
|
|
||||||
return False # Empty files are not binary
|
|
||||||
|
|
||||||
if magic:
|
|
||||||
try:
|
|
||||||
mime = magic.from_file(filepath, mime=True)
|
|
||||||
file_type = magic.from_file(filepath)
|
|
||||||
|
|
||||||
# If MIME type starts with 'text/', it's likely a text file
|
|
||||||
if mime.startswith("text/"):
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Also consider 'application/x-python' and similar script types as text
|
|
||||||
if any(mime.startswith(prefix) for prefix in ['application/x-python', 'application/javascript']):
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Check for common text file descriptors
|
|
||||||
text_indicators = ["text", "script", "xml", "json", "yaml", "markdown", "HTML"]
|
|
||||||
if any(indicator.lower() in file_type.lower() for indicator in text_indicators):
|
|
||||||
return False
|
|
||||||
|
|
||||||
# If none of the text indicators are present, assume it's binary
|
|
||||||
return True
|
|
||||||
except Exception:
|
|
||||||
return _is_binary_fallback(filepath)
|
|
||||||
else:
|
|
||||||
return _is_binary_fallback(filepath)
|
|
||||||
|
|
||||||
|
|
||||||
def _is_binary_fallback(filepath):
|
|
||||||
"""Fallback method to detect binary files without using magic."""
|
|
||||||
try:
|
|
||||||
# First check if file is empty
|
|
||||||
if os.path.getsize(filepath) == 0:
|
|
||||||
return False # Empty files are not binary
|
|
||||||
|
|
||||||
with open(filepath, "r", encoding="utf-8") as f:
|
|
||||||
chunk = f.read(1024)
|
|
||||||
|
|
||||||
# Check for null bytes which indicate binary content
|
|
||||||
if "\0" in chunk:
|
|
||||||
return True
|
|
||||||
|
|
||||||
# If we can read it as text without errors, it's probably not binary
|
|
||||||
return False
|
|
||||||
except UnicodeDecodeError:
|
|
||||||
# If we can't decode as UTF-8, it's likely binary
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
def get_work_log() -> str:
|
def get_work_log() -> str:
|
||||||
|
|
@ -518,17 +459,18 @@ def reset_work_log() -> str:
|
||||||
|
|
||||||
@tool("deregister_related_files")
|
@tool("deregister_related_files")
|
||||||
def deregister_related_files(file_ids: List[int]) -> str:
|
def deregister_related_files(file_ids: List[int]) -> str:
|
||||||
"""Delete multiple related files from global memory by their IDs.
|
"""Delete multiple related files by their IDs.
|
||||||
Silently skips any IDs that don't exist.
|
Silently skips any IDs that don't exist.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
file_ids: List of file IDs to delete
|
file_ids: List of file IDs to delete
|
||||||
"""
|
"""
|
||||||
|
repo = get_related_files_repository()
|
||||||
results = []
|
results = []
|
||||||
|
|
||||||
for file_id in file_ids:
|
for file_id in file_ids:
|
||||||
if file_id in _global_memory["related_files"]:
|
deleted_file = repo.remove_file(file_id)
|
||||||
# Delete the file reference
|
if deleted_file:
|
||||||
deleted_file = _global_memory["related_files"].pop(file_id)
|
|
||||||
success_msg = (
|
success_msg = (
|
||||||
f"Successfully removed related file #{file_id}: {deleted_file}"
|
f"Successfully removed related file #{file_id}: {deleted_file}"
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -14,6 +14,7 @@ from ra_aid.models_params import DEFAULT_BASE_LATENCY, models_params
|
||||||
from ra_aid.proc.interactive import run_interactive_command
|
from ra_aid.proc.interactive import run_interactive_command
|
||||||
from ra_aid.text.processing import truncate_output
|
from ra_aid.text.processing import truncate_output
|
||||||
from ra_aid.tools.memory import _global_memory, log_work_event
|
from ra_aid.tools.memory import _global_memory, log_work_event
|
||||||
|
from ra_aid.database.repositories.related_files_repository import get_related_files_repository
|
||||||
|
|
||||||
console = Console()
|
console = Console()
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
@ -91,13 +92,18 @@ def run_programming_task(
|
||||||
|
|
||||||
# Get combined list of files (explicit + related) with normalized paths
|
# Get combined list of files (explicit + related) with normalized paths
|
||||||
# and deduplicated using set operations
|
# and deduplicated using set operations
|
||||||
|
related_files_paths = []
|
||||||
|
try:
|
||||||
|
repo = get_related_files_repository()
|
||||||
|
related_files_paths = list(repo.get_all().values())
|
||||||
|
logger.debug("Retrieved related files from repository")
|
||||||
|
except RuntimeError as e:
|
||||||
|
# Repository not initialized
|
||||||
|
logger.warning(f"Failed to get related files repository: {e}")
|
||||||
|
|
||||||
files_to_use = list(
|
files_to_use = list(
|
||||||
{os.path.abspath(f) for f in (files or [])}
|
{os.path.abspath(f) for f in (files or [])}
|
||||||
| {
|
| {os.path.abspath(f) for f in related_files_paths}
|
||||||
os.path.abspath(f)
|
|
||||||
for f in _global_memory["related_files"].values()
|
|
||||||
if "related_files" in _global_memory
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add config file if specified
|
# Add config file if specified
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,5 @@
|
||||||
|
"""Utility functions for the ra-aid project."""
|
||||||
|
|
||||||
|
from .file_utils import is_binary_file
|
||||||
|
|
||||||
|
__all__ = ["is_binary_file"]
|
||||||
|
|
@ -0,0 +1,61 @@
|
||||||
|
"""Utility functions for file operations."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
try:
|
||||||
|
import magic
|
||||||
|
except ImportError:
|
||||||
|
magic = None
|
||||||
|
|
||||||
|
|
||||||
|
def is_binary_file(filepath):
|
||||||
|
"""Check if a file is binary using magic library if available."""
|
||||||
|
# First check if file is empty
|
||||||
|
if os.path.getsize(filepath) == 0:
|
||||||
|
return False # Empty files are not binary
|
||||||
|
|
||||||
|
if magic:
|
||||||
|
try:
|
||||||
|
mime = magic.from_file(filepath, mime=True)
|
||||||
|
file_type = magic.from_file(filepath)
|
||||||
|
|
||||||
|
# If MIME type starts with 'text/', it's likely a text file
|
||||||
|
if mime.startswith("text/"):
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Also consider 'application/x-python' and similar script types as text
|
||||||
|
if any(mime.startswith(prefix) for prefix in ['application/x-python', 'application/javascript']):
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Check for common text file descriptors
|
||||||
|
text_indicators = ["text", "script", "xml", "json", "yaml", "markdown", "HTML"]
|
||||||
|
if any(indicator.lower() in file_type.lower() for indicator in text_indicators):
|
||||||
|
return False
|
||||||
|
|
||||||
|
# If none of the text indicators are present, assume it's binary
|
||||||
|
return True
|
||||||
|
except Exception:
|
||||||
|
return _is_binary_fallback(filepath)
|
||||||
|
else:
|
||||||
|
return _is_binary_fallback(filepath)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_binary_fallback(filepath):
|
||||||
|
"""Fallback method to detect binary files without using magic."""
|
||||||
|
try:
|
||||||
|
# First check if file is empty
|
||||||
|
if os.path.getsize(filepath) == 0:
|
||||||
|
return False # Empty files are not binary
|
||||||
|
|
||||||
|
with open(filepath, "r", encoding="utf-8") as f:
|
||||||
|
chunk = f.read(1024)
|
||||||
|
|
||||||
|
# Check for null bytes which indicate binary content
|
||||||
|
if "\0" in chunk:
|
||||||
|
return True
|
||||||
|
|
||||||
|
# If we can read it as text without errors, it's probably not binary
|
||||||
|
return False
|
||||||
|
except UnicodeDecodeError:
|
||||||
|
# If we can't decode as UTF-8, it's likely binary
|
||||||
|
return True
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
"""Unit tests for __main__.py argument parsing."""
|
"""Unit tests for __main__.py argument parsing."""
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from unittest.mock import patch, MagicMock
|
||||||
|
|
||||||
from ra_aid.__main__ import parse_arguments
|
from ra_aid.__main__ import parse_arguments
|
||||||
from ra_aid.config import DEFAULT_RECURSION_LIMIT
|
from ra_aid.config import DEFAULT_RECURSION_LIMIT
|
||||||
|
|
@ -12,8 +13,6 @@ def mock_dependencies(monkeypatch):
|
||||||
"""Mock all dependencies needed for main()."""
|
"""Mock all dependencies needed for main()."""
|
||||||
# Initialize global memory with necessary keys to prevent KeyError
|
# Initialize global memory with necessary keys to prevent KeyError
|
||||||
_global_memory.clear()
|
_global_memory.clear()
|
||||||
_global_memory["related_files"] = {}
|
|
||||||
_global_memory["related_file_id_counter"] = 1
|
|
||||||
_global_memory["agent_depth"] = 0
|
_global_memory["agent_depth"] = 0
|
||||||
_global_memory["work_log"] = []
|
_global_memory["work_log"] = []
|
||||||
_global_memory["config"] = {}
|
_global_memory["config"] = {}
|
||||||
|
|
@ -37,6 +36,28 @@ def mock_dependencies(monkeypatch):
|
||||||
monkeypatch.setattr("ra_aid.__main__.initialize_llm", mock_config_update)
|
monkeypatch.setattr("ra_aid.__main__.initialize_llm", mock_config_update)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def mock_related_files_repository():
|
||||||
|
"""Mock the RelatedFilesRepository to avoid database operations during tests"""
|
||||||
|
with patch('ra_aid.database.repositories.related_files_repository.related_files_repo_var') as mock_repo_var:
|
||||||
|
# Setup a mock repository
|
||||||
|
mock_repo = MagicMock()
|
||||||
|
|
||||||
|
# Create a dictionary to simulate stored files
|
||||||
|
related_files = {}
|
||||||
|
|
||||||
|
# Setup get_all method to return the files dict
|
||||||
|
mock_repo.get_all.return_value = related_files
|
||||||
|
|
||||||
|
# Setup format_related_files method
|
||||||
|
mock_repo.format_related_files.return_value = [f"ID#{file_id} {filepath}" for file_id, filepath in sorted(related_files.items())]
|
||||||
|
|
||||||
|
# Make the mock context var return our mock repo
|
||||||
|
mock_repo_var.get.return_value = mock_repo
|
||||||
|
|
||||||
|
yield mock_repo
|
||||||
|
|
||||||
|
|
||||||
def test_recursion_limit_in_global_config(mock_dependencies):
|
def test_recursion_limit_in_global_config(mock_dependencies):
|
||||||
"""Test that recursion limit is correctly set in global config."""
|
"""Test that recursion limit is correctly set in global config."""
|
||||||
import sys
|
import sys
|
||||||
|
|
@ -122,8 +143,6 @@ def test_temperature_validation(mock_dependencies):
|
||||||
|
|
||||||
# Reset global memory for clean test
|
# Reset global memory for clean test
|
||||||
_global_memory.clear()
|
_global_memory.clear()
|
||||||
_global_memory["related_files"] = {}
|
|
||||||
_global_memory["related_file_id_counter"] = 1
|
|
||||||
_global_memory["agent_depth"] = 0
|
_global_memory["agent_depth"] = 0
|
||||||
_global_memory["work_log"] = []
|
_global_memory["work_log"] = []
|
||||||
_global_memory["config"] = {}
|
_global_memory["config"] = {}
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,57 @@
|
||||||
import pytest
|
import pytest
|
||||||
|
from unittest.mock import patch, MagicMock
|
||||||
|
|
||||||
from ra_aid.tools.programmer import (
|
from ra_aid.tools.programmer import (
|
||||||
get_aider_executable,
|
get_aider_executable,
|
||||||
parse_aider_flags,
|
parse_aider_flags,
|
||||||
run_programming_task,
|
run_programming_task,
|
||||||
)
|
)
|
||||||
|
from ra_aid.database.repositories.related_files_repository import get_related_files_repository
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def mock_related_files_repository():
|
||||||
|
"""Mock the RelatedFilesRepository to avoid database operations during tests"""
|
||||||
|
with patch('ra_aid.database.repositories.related_files_repository.related_files_repo_var') as mock_repo_var:
|
||||||
|
# Setup a mock repository
|
||||||
|
mock_repo = MagicMock()
|
||||||
|
|
||||||
|
# Create a dictionary to simulate stored files
|
||||||
|
related_files = {}
|
||||||
|
|
||||||
|
# Setup get_all method to return the files dict
|
||||||
|
mock_repo.get_all.return_value = related_files
|
||||||
|
|
||||||
|
# Setup format_related_files method
|
||||||
|
mock_repo.format_related_files.return_value = [f"ID#{file_id} {filepath}" for file_id, filepath in sorted(related_files.items())]
|
||||||
|
|
||||||
|
# Setup add_file method
|
||||||
|
def mock_add_file(filepath):
|
||||||
|
normalized_path = os.path.abspath(filepath)
|
||||||
|
# Check if path already exists
|
||||||
|
for file_id, path in related_files.items():
|
||||||
|
if path == normalized_path:
|
||||||
|
return file_id
|
||||||
|
|
||||||
|
# Add new file
|
||||||
|
file_id = len(related_files) + 1
|
||||||
|
related_files[file_id] = normalized_path
|
||||||
|
return file_id
|
||||||
|
mock_repo.add_file.side_effect = mock_add_file
|
||||||
|
|
||||||
|
# Setup remove_file method
|
||||||
|
def mock_remove_file(file_id):
|
||||||
|
if file_id in related_files:
|
||||||
|
return related_files.pop(file_id)
|
||||||
|
return None
|
||||||
|
mock_repo.remove_file.side_effect = mock_remove_file
|
||||||
|
|
||||||
|
# Make the mock context var return our mock repo
|
||||||
|
mock_repo_var.get.return_value = mock_repo
|
||||||
|
|
||||||
|
# Also patch the get_related_files_repository function
|
||||||
|
with patch('ra_aid.tools.programmer.get_related_files_repository', return_value=mock_repo):
|
||||||
|
yield mock_repo
|
||||||
|
|
||||||
|
|
||||||
# Test cases for parse_aider_flags function
|
# Test cases for parse_aider_flags function
|
||||||
test_cases = [
|
test_cases = [
|
||||||
|
|
@ -78,11 +125,11 @@ def test_parse_aider_flags(input_flags, expected, description):
|
||||||
assert result == expected, f"Failed test case: {description}"
|
assert result == expected, f"Failed test case: {description}"
|
||||||
|
|
||||||
|
|
||||||
def test_aider_config_flag(mocker):
|
def test_aider_config_flag(mocker, mock_related_files_repository):
|
||||||
"""Test that aider config flag is properly included in the command when specified."""
|
"""Test that aider config flag is properly included in the command when specified."""
|
||||||
|
# Mock config in global memory but not related files (using repository now)
|
||||||
mock_memory = {
|
mock_memory = {
|
||||||
"config": {"aider_config": "/path/to/config.yml"},
|
"config": {"aider_config": "/path/to/config.yml"},
|
||||||
"related_files": {},
|
|
||||||
}
|
}
|
||||||
mocker.patch("ra_aid.tools.programmer._global_memory", mock_memory)
|
mocker.patch("ra_aid.tools.programmer._global_memory", mock_memory)
|
||||||
|
|
||||||
|
|
@ -99,15 +146,15 @@ def test_aider_config_flag(mocker):
|
||||||
assert args[config_index + 1] == "/path/to/config.yml"
|
assert args[config_index + 1] == "/path/to/config.yml"
|
||||||
|
|
||||||
|
|
||||||
def test_path_normalization_and_deduplication(mocker, tmp_path):
|
def test_path_normalization_and_deduplication(mocker, tmp_path, mock_related_files_repository):
|
||||||
"""Test path normalization and deduplication in run_programming_task."""
|
"""Test path normalization and deduplication in run_programming_task."""
|
||||||
# Create a temporary test file
|
# Create a temporary test file
|
||||||
test_file = tmp_path / "test.py"
|
test_file = tmp_path / "test.py"
|
||||||
test_file.write_text("")
|
test_file.write_text("")
|
||||||
new_file = tmp_path / "new.py"
|
new_file = tmp_path / "new.py"
|
||||||
|
|
||||||
# Mock dependencies
|
# Mock dependencies - only need to mock config part of global memory now
|
||||||
mocker.patch("ra_aid.tools.programmer._global_memory", {"related_files": {}})
|
mocker.patch("ra_aid.tools.programmer._global_memory", {"config": {}})
|
||||||
mocker.patch(
|
mocker.patch(
|
||||||
"ra_aid.tools.programmer.get_aider_executable", return_value="/path/to/aider"
|
"ra_aid.tools.programmer.get_aider_executable", return_value="/path/to/aider"
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,7 @@
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from unittest.mock import patch, MagicMock
|
from unittest.mock import patch, MagicMock
|
||||||
|
import os
|
||||||
|
|
||||||
from ra_aid.tools.agent import (
|
from ra_aid.tools.agent import (
|
||||||
request_research,
|
request_research,
|
||||||
|
|
@ -10,21 +11,39 @@ from ra_aid.tools.agent import (
|
||||||
request_research_and_implementation,
|
request_research_and_implementation,
|
||||||
)
|
)
|
||||||
from ra_aid.tools.memory import _global_memory
|
from ra_aid.tools.memory import _global_memory
|
||||||
|
from ra_aid.database.repositories.related_files_repository import get_related_files_repository
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def reset_memory():
|
def reset_memory():
|
||||||
"""Reset global memory before each test"""
|
"""Reset global memory before each test"""
|
||||||
_global_memory["related_files"] = {}
|
|
||||||
_global_memory["related_file_id_counter"] = 0
|
|
||||||
_global_memory["work_log"] = []
|
_global_memory["work_log"] = []
|
||||||
yield
|
yield
|
||||||
# Clean up after test
|
# Clean up after test
|
||||||
_global_memory["related_files"] = {}
|
|
||||||
_global_memory["related_file_id_counter"] = 0
|
|
||||||
_global_memory["work_log"] = []
|
_global_memory["work_log"] = []
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def mock_related_files_repository():
|
||||||
|
"""Mock the RelatedFilesRepository to avoid database operations during tests"""
|
||||||
|
with patch('ra_aid.database.repositories.related_files_repository.related_files_repo_var') as mock_repo_var:
|
||||||
|
# Setup a mock repository
|
||||||
|
mock_repo = MagicMock()
|
||||||
|
|
||||||
|
# Create a dictionary to simulate stored files
|
||||||
|
related_files = {}
|
||||||
|
|
||||||
|
# Setup get_all method to return the files dict
|
||||||
|
mock_repo.get_all.return_value = related_files
|
||||||
|
|
||||||
|
# Setup format_related_files method
|
||||||
|
mock_repo.format_related_files.return_value = [f"ID#{file_id} {filepath}" for file_id, filepath in sorted(related_files.items())]
|
||||||
|
|
||||||
|
# Make the mock context var return our mock repo
|
||||||
|
mock_repo_var.get.return_value = mock_repo
|
||||||
|
|
||||||
|
yield mock_repo
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_functions():
|
def mock_functions():
|
||||||
"""Mock functions used in agent.py"""
|
"""Mock functions used in agent.py"""
|
||||||
|
|
|
||||||
|
|
@ -17,11 +17,11 @@ from ra_aid.tools.memory import (
|
||||||
get_work_log,
|
get_work_log,
|
||||||
log_work_event,
|
log_work_event,
|
||||||
reset_work_log,
|
reset_work_log,
|
||||||
is_binary_file,
|
|
||||||
_is_binary_fallback,
|
|
||||||
)
|
)
|
||||||
|
from ra_aid.utils.file_utils import is_binary_file, _is_binary_fallback
|
||||||
from ra_aid.database.repositories.key_fact_repository import get_key_fact_repository
|
from ra_aid.database.repositories.key_fact_repository import get_key_fact_repository
|
||||||
from ra_aid.database.repositories.key_snippet_repository import get_key_snippet_repository
|
from ra_aid.database.repositories.key_snippet_repository import get_key_snippet_repository
|
||||||
|
from ra_aid.database.repositories.related_files_repository import get_related_files_repository
|
||||||
from ra_aid.database.connection import DatabaseManager
|
from ra_aid.database.connection import DatabaseManager
|
||||||
from ra_aid.database.models import KeyFact
|
from ra_aid.database.models import KeyFact
|
||||||
|
|
||||||
|
|
@ -29,13 +29,9 @@ from ra_aid.database.models import KeyFact
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def reset_memory():
|
def reset_memory():
|
||||||
"""Reset global memory before each test"""
|
"""Reset global memory before each test"""
|
||||||
_global_memory["related_files"] = {}
|
|
||||||
_global_memory["related_file_id_counter"] = 0
|
|
||||||
_global_memory["work_log"] = []
|
_global_memory["work_log"] = []
|
||||||
yield
|
yield
|
||||||
# Clean up after test
|
# Clean up after test
|
||||||
_global_memory["related_files"] = {}
|
|
||||||
_global_memory["related_file_id_counter"] = 0
|
|
||||||
_global_memory["work_log"] = []
|
_global_memory["work_log"] = []
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -165,6 +161,66 @@ def mock_key_snippet_repository():
|
||||||
yield memory_mock_repo
|
yield memory_mock_repo
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def mock_related_files_repository():
|
||||||
|
"""Mock the RelatedFilesRepository to avoid database operations during tests"""
|
||||||
|
with patch('ra_aid.tools.memory.get_related_files_repository') as mock_repo:
|
||||||
|
# Setup the mock repository to behave like the original, but using memory
|
||||||
|
related_files = {} # Local in-memory storage
|
||||||
|
id_counter = 0
|
||||||
|
|
||||||
|
# Mock add_file method
|
||||||
|
def mock_add_file(filepath):
|
||||||
|
nonlocal id_counter
|
||||||
|
# Check if normalized path already exists in values
|
||||||
|
normalized_path = os.path.abspath(filepath)
|
||||||
|
for file_id, path in related_files.items():
|
||||||
|
if path == normalized_path:
|
||||||
|
return file_id
|
||||||
|
|
||||||
|
# First check if path exists
|
||||||
|
if not os.path.exists(filepath):
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Then check if it's a directory
|
||||||
|
if os.path.isdir(filepath):
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Validate it's a regular file
|
||||||
|
if not os.path.isfile(filepath):
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Check if it's a binary file (don't actually check in tests)
|
||||||
|
# We'll mock is_binary_file separately when needed
|
||||||
|
|
||||||
|
# Add new file
|
||||||
|
file_id = id_counter
|
||||||
|
id_counter += 1
|
||||||
|
related_files[file_id] = normalized_path
|
||||||
|
|
||||||
|
return file_id
|
||||||
|
mock_repo.return_value.add_file.side_effect = mock_add_file
|
||||||
|
|
||||||
|
# Mock get_all method
|
||||||
|
def mock_get_all():
|
||||||
|
return related_files.copy()
|
||||||
|
mock_repo.return_value.get_all.side_effect = mock_get_all
|
||||||
|
|
||||||
|
# Mock remove_file method
|
||||||
|
def mock_remove_file(file_id):
|
||||||
|
if file_id in related_files:
|
||||||
|
return related_files.pop(file_id)
|
||||||
|
return None
|
||||||
|
mock_repo.return_value.remove_file.side_effect = mock_remove_file
|
||||||
|
|
||||||
|
# Mock format_related_files method
|
||||||
|
def mock_format_related_files():
|
||||||
|
return [f"ID#{file_id} {filepath}" for file_id, filepath in sorted(related_files.items())]
|
||||||
|
mock_repo.return_value.format_related_files.side_effect = mock_format_related_files
|
||||||
|
|
||||||
|
yield mock_repo
|
||||||
|
|
||||||
|
|
||||||
def test_emit_key_facts_single_fact(reset_memory, mock_repository):
|
def test_emit_key_facts_single_fact(reset_memory, mock_repository):
|
||||||
"""Test emitting a single key fact using emit_key_facts"""
|
"""Test emitting a single key fact using emit_key_facts"""
|
||||||
# Test with single fact
|
# Test with single fact
|
||||||
|
|
@ -177,9 +233,6 @@ def test_emit_key_facts_single_fact(reset_memory, mock_repository):
|
||||||
|
|
||||||
def test_get_memory_value_other_types(reset_memory):
|
def test_get_memory_value_other_types(reset_memory):
|
||||||
"""Test get_memory_value remains compatible with other memory types"""
|
"""Test get_memory_value remains compatible with other memory types"""
|
||||||
# Test with empty list
|
|
||||||
assert get_memory_value("plans") == ""
|
|
||||||
|
|
||||||
# Test with non-existent key
|
# Test with non-existent key
|
||||||
assert get_memory_value("nonexistent") == ""
|
assert get_memory_value("nonexistent") == ""
|
||||||
|
|
||||||
|
|
@ -263,7 +316,7 @@ def test_emit_key_facts_triggers_cleaner(reset_memory, mock_repository):
|
||||||
"""Test that emit_key_facts triggers the cleaner agent when there are more than 30 facts"""
|
"""Test that emit_key_facts triggers the cleaner agent when there are more than 30 facts"""
|
||||||
# Setup mock repository to return more than 30 facts
|
# Setup mock repository to return more than 30 facts
|
||||||
facts = []
|
facts = []
|
||||||
for i in range(31):
|
for i in range(51):
|
||||||
facts.append(MagicMock(id=i, content=f"Test fact {i}", human_input_id=None))
|
facts.append(MagicMock(id=i, content=f"Test fact {i}", human_input_id=None))
|
||||||
|
|
||||||
# Mock the get_all method to return more than 30 facts
|
# Mock the get_all method to return more than 30 facts
|
||||||
|
|
@ -407,7 +460,7 @@ def test_delete_key_snippets_empty(mock_log_work_event, reset_memory, mock_key_s
|
||||||
mock_key_snippet_repository.return_value.delete.assert_not_called()
|
mock_key_snippet_repository.return_value.delete.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
def test_emit_related_files_basic(reset_memory, tmp_path):
|
def test_emit_related_files_basic(reset_memory, mock_related_files_repository, tmp_path):
|
||||||
"""Test basic adding of files with ID tracking"""
|
"""Test basic adding of files with ID tracking"""
|
||||||
# Create test files
|
# Create test files
|
||||||
test_file = tmp_path / "test.py"
|
test_file = tmp_path / "test.py"
|
||||||
|
|
@ -420,23 +473,26 @@ def test_emit_related_files_basic(reset_memory, tmp_path):
|
||||||
# Test adding single file
|
# Test adding single file
|
||||||
result = emit_related_files.invoke({"files": [str(test_file)]})
|
result = emit_related_files.invoke({"files": [str(test_file)]})
|
||||||
assert result == "Files noted."
|
assert result == "Files noted."
|
||||||
assert _global_memory["related_files"][0] == str(test_file)
|
# Verify file was added using the repository
|
||||||
|
mock_related_files_repository.return_value.add_file.assert_called_with(str(test_file))
|
||||||
|
|
||||||
# Test adding multiple files
|
# Test adding multiple files
|
||||||
result = emit_related_files.invoke({"files": [str(main_file), str(utils_file)]})
|
result = emit_related_files.invoke({"files": [str(main_file), str(utils_file)]})
|
||||||
assert result == "Files noted."
|
assert result == "Files noted."
|
||||||
# Verify both files exist in related_files
|
# Verify both files were added
|
||||||
values = list(_global_memory["related_files"].values())
|
mock_related_files_repository.return_value.add_file.assert_any_call(str(main_file))
|
||||||
assert str(main_file) in values
|
mock_related_files_repository.return_value.add_file.assert_any_call(str(utils_file))
|
||||||
assert str(utils_file) in values
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_related_files_empty(reset_memory):
|
def test_get_related_files_empty(reset_memory, mock_related_files_repository):
|
||||||
"""Test getting related files when none added"""
|
"""Test getting related files when none added"""
|
||||||
|
# Mock empty format_related_files result
|
||||||
|
mock_related_files_repository.return_value.format_related_files.return_value = []
|
||||||
assert get_related_files() == []
|
assert get_related_files() == []
|
||||||
|
mock_related_files_repository.return_value.format_related_files.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
def test_emit_related_files_duplicates(reset_memory, tmp_path):
|
def test_emit_related_files_duplicates(reset_memory, mock_related_files_repository, tmp_path):
|
||||||
"""Test that duplicate files return existing IDs with proper formatting"""
|
"""Test that duplicate files return existing IDs with proper formatting"""
|
||||||
# Create test files
|
# Create test files
|
||||||
test_file = tmp_path / "test.py"
|
test_file = tmp_path / "test.py"
|
||||||
|
|
@ -446,46 +502,34 @@ def test_emit_related_files_duplicates(reset_memory, tmp_path):
|
||||||
new_file = tmp_path / "new.py"
|
new_file = tmp_path / "new.py"
|
||||||
new_file.write_text("# New file")
|
new_file.write_text("# New file")
|
||||||
|
|
||||||
|
# Mock add_file to return consistent IDs
|
||||||
|
def mock_add_file(filepath):
|
||||||
|
if "test.py" in filepath:
|
||||||
|
return 0
|
||||||
|
elif "main.py" in filepath:
|
||||||
|
return 1
|
||||||
|
elif "new.py" in filepath:
|
||||||
|
return 2
|
||||||
|
return None
|
||||||
|
mock_related_files_repository.return_value.add_file.side_effect = mock_add_file
|
||||||
|
|
||||||
# Add initial files
|
# Add initial files
|
||||||
result1 = emit_related_files.invoke({"files": [str(test_file), str(main_file)]})
|
result1 = emit_related_files.invoke({"files": [str(test_file), str(main_file)]})
|
||||||
assert result1 == "Files noted."
|
assert result1 == "Files noted."
|
||||||
_first_id = 0 # ID of test.py
|
|
||||||
|
|
||||||
# Try adding duplicates
|
# Try adding duplicates
|
||||||
result2 = emit_related_files.invoke({"files": [str(test_file)]})
|
result2 = emit_related_files.invoke({"files": [str(test_file)]})
|
||||||
assert result2 == "Files noted."
|
assert result2 == "Files noted."
|
||||||
assert len(_global_memory["related_files"]) == 2 # Count should not increase
|
|
||||||
|
|
||||||
# Try mix of new and duplicate files
|
# Try mix of new and duplicate files
|
||||||
result = emit_related_files.invoke({"files": [str(test_file), str(new_file)]})
|
result = emit_related_files.invoke({"files": [str(test_file), str(new_file)]})
|
||||||
assert result == "Files noted."
|
assert result == "Files noted."
|
||||||
assert len(_global_memory["related_files"]) == 3
|
|
||||||
|
# Verify calls to add_file - should be called for each file (even duplicates)
|
||||||
|
assert mock_related_files_repository.return_value.add_file.call_count == 5
|
||||||
|
|
||||||
|
|
||||||
def test_related_files_id_tracking(reset_memory, tmp_path):
|
def test_deregister_related_files(reset_memory, mock_related_files_repository, tmp_path):
|
||||||
"""Test ID assignment and counter functionality for related files"""
|
|
||||||
# Create test files
|
|
||||||
file1 = tmp_path / "file1.py"
|
|
||||||
file1.write_text("# File 1")
|
|
||||||
file2 = tmp_path / "file2.py"
|
|
||||||
file2.write_text("# File 2")
|
|
||||||
|
|
||||||
# Add first file
|
|
||||||
result = emit_related_files.invoke({"files": [str(file1)]})
|
|
||||||
assert result == "Files noted."
|
|
||||||
assert _global_memory["related_file_id_counter"] == 1
|
|
||||||
|
|
||||||
# Add second file
|
|
||||||
result = emit_related_files.invoke({"files": [str(file2)]})
|
|
||||||
assert result == "Files noted."
|
|
||||||
assert _global_memory["related_file_id_counter"] == 2
|
|
||||||
|
|
||||||
# Verify all files stored correctly
|
|
||||||
assert _global_memory["related_files"][0] == str(file1)
|
|
||||||
assert _global_memory["related_files"][1] == str(file2)
|
|
||||||
|
|
||||||
|
|
||||||
def test_deregister_related_files(reset_memory, tmp_path):
|
|
||||||
"""Test deleting related files"""
|
"""Test deleting related files"""
|
||||||
# Create test files
|
# Create test files
|
||||||
file1 = tmp_path / "file1.py"
|
file1 = tmp_path / "file1.py"
|
||||||
|
|
@ -495,276 +539,108 @@ def test_deregister_related_files(reset_memory, tmp_path):
|
||||||
file3 = tmp_path / "file3.py"
|
file3 = tmp_path / "file3.py"
|
||||||
file3.write_text("# File 3")
|
file3.write_text("# File 3")
|
||||||
|
|
||||||
# Add test files
|
# Mock remove_file to return file paths for existing IDs
|
||||||
emit_related_files.invoke({"files": [str(file1), str(file2), str(file3)]})
|
def mock_remove_file(file_id):
|
||||||
|
if file_id == 0:
|
||||||
|
return str(file1)
|
||||||
|
elif file_id == 1:
|
||||||
|
return str(file2)
|
||||||
|
elif file_id == 2:
|
||||||
|
return str(file3)
|
||||||
|
return None
|
||||||
|
mock_related_files_repository.return_value.remove_file.side_effect = mock_remove_file
|
||||||
|
|
||||||
# Delete middle file
|
# Delete middle file
|
||||||
result = deregister_related_files.invoke({"file_ids": [1]})
|
result = deregister_related_files.invoke({"file_ids": [1]})
|
||||||
assert result == "Files noted."
|
assert result == "Files noted."
|
||||||
assert 1 not in _global_memory["related_files"]
|
mock_related_files_repository.return_value.remove_file.assert_called_with(1)
|
||||||
assert len(_global_memory["related_files"]) == 2
|
|
||||||
|
|
||||||
# Delete multiple files including non-existent ID
|
# Delete multiple files including non-existent ID
|
||||||
result = deregister_related_files.invoke({"file_ids": [0, 2, 999]})
|
result = deregister_related_files.invoke({"file_ids": [0, 2, 999]})
|
||||||
assert result == "Files noted."
|
assert result == "Files noted."
|
||||||
assert len(_global_memory["related_files"]) == 0
|
mock_related_files_repository.return_value.remove_file.assert_any_call(0)
|
||||||
|
mock_related_files_repository.return_value.remove_file.assert_any_call(2)
|
||||||
# Counter should remain unchanged after deletions
|
mock_related_files_repository.return_value.remove_file.assert_any_call(999)
|
||||||
assert _global_memory["related_file_id_counter"] == 3
|
|
||||||
|
|
||||||
|
|
||||||
def test_related_files_duplicates(reset_memory, tmp_path):
|
def test_emit_related_files_path_normalization(reset_memory, mock_related_files_repository, tmp_path):
|
||||||
"""Test duplicate file handling returns same ID"""
|
"""Test that emit_related_files normalization works correctly"""
|
||||||
# Create test file
|
|
||||||
test_file = tmp_path / "test.py"
|
|
||||||
test_file.write_text("# Test file")
|
|
||||||
|
|
||||||
# Add initial file
|
|
||||||
result1 = emit_related_files.invoke({"files": [str(test_file)]})
|
|
||||||
assert result1 == "Files noted."
|
|
||||||
|
|
||||||
# Add same file again
|
|
||||||
result2 = emit_related_files.invoke({"files": [str(test_file)]})
|
|
||||||
assert result2 == "Files noted."
|
|
||||||
|
|
||||||
# Verify only one entry exists
|
|
||||||
assert len(_global_memory["related_files"]) == 1
|
|
||||||
assert _global_memory["related_file_id_counter"] == 1
|
|
||||||
|
|
||||||
|
|
||||||
def test_emit_related_files_with_directory(reset_memory, tmp_path):
|
|
||||||
"""Test that directories and non-existent paths are rejected while valid files are added"""
|
|
||||||
# Create test directory and file
|
|
||||||
test_dir = tmp_path / "test_dir"
|
|
||||||
test_dir.mkdir()
|
|
||||||
test_file = tmp_path / "test_file.txt"
|
|
||||||
test_file.write_text("test content")
|
|
||||||
nonexistent = tmp_path / "does_not_exist.txt"
|
|
||||||
|
|
||||||
# Try to emit directory, nonexistent path, and valid file
|
|
||||||
result = emit_related_files.invoke(
|
|
||||||
{"files": [str(test_dir), str(nonexistent), str(test_file)]}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Verify result is the standard message
|
|
||||||
assert result == "Files noted."
|
|
||||||
|
|
||||||
# Verify directory and nonexistent not added but valid file was
|
|
||||||
assert len(_global_memory["related_files"]) == 1
|
|
||||||
values = list(_global_memory["related_files"].values())
|
|
||||||
assert str(test_file) in values
|
|
||||||
assert str(test_dir) not in values
|
|
||||||
assert str(nonexistent) not in values
|
|
||||||
|
|
||||||
|
|
||||||
def test_related_files_formatting(reset_memory, tmp_path):
|
|
||||||
"""Test related files output string formatting"""
|
|
||||||
# Create test files
|
|
||||||
file1 = tmp_path / "file1.py"
|
|
||||||
file1.write_text("# File 1")
|
|
||||||
file2 = tmp_path / "file2.py"
|
|
||||||
file2.write_text("# File 2")
|
|
||||||
|
|
||||||
# Add some files
|
|
||||||
emit_related_files.invoke({"files": [str(file1), str(file2)]})
|
|
||||||
|
|
||||||
# Get formatted output
|
|
||||||
output = get_memory_value("related_files")
|
|
||||||
# Expect just the IDs on separate lines
|
|
||||||
expected = "0\n1"
|
|
||||||
assert output == expected
|
|
||||||
|
|
||||||
# Test empty case
|
|
||||||
_global_memory["related_files"] = {}
|
|
||||||
assert get_memory_value("related_files") == ""
|
|
||||||
|
|
||||||
|
|
||||||
def test_emit_related_files_path_normalization(reset_memory, tmp_path):
|
|
||||||
"""Test that emit_related_files fails to detect duplicates with non-normalized paths"""
|
|
||||||
# Create a test file
|
# Create a test file
|
||||||
test_file = tmp_path / "file.txt"
|
test_file = tmp_path / "file.txt"
|
||||||
test_file.write_text("test content")
|
test_file.write_text("test content")
|
||||||
|
|
||||||
# Change to the temp directory so relative paths work
|
# Change to the temp directory so relative paths work
|
||||||
import os
|
|
||||||
|
|
||||||
original_dir = os.getcwd()
|
original_dir = os.getcwd()
|
||||||
os.chdir(tmp_path)
|
os.chdir(tmp_path)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Add file with absolute path
|
# Set up mock to test path normalization
|
||||||
|
def mock_add_file(filepath):
|
||||||
|
# The repository normalizes paths before comparing
|
||||||
|
# This mock simulates that behavior
|
||||||
|
normalized_path = os.path.abspath(filepath)
|
||||||
|
if normalized_path == os.path.abspath("file.txt"):
|
||||||
|
return 0
|
||||||
|
return None
|
||||||
|
mock_related_files_repository.return_value.add_file.side_effect = mock_add_file
|
||||||
|
|
||||||
|
# Add file with relative path
|
||||||
result1 = emit_related_files.invoke({"files": ["file.txt"]})
|
result1 = emit_related_files.invoke({"files": ["file.txt"]})
|
||||||
assert result1 == "Files noted."
|
assert result1 == "Files noted."
|
||||||
|
|
||||||
# Add same file with relative path - should get same ID due to path normalization
|
# Add same file with different relative path - should get same ID
|
||||||
result2 = emit_related_files.invoke({"files": ["./file.txt"]})
|
result2 = emit_related_files.invoke({"files": ["./file.txt"]})
|
||||||
assert result2 == "Files noted."
|
assert result2 == "Files noted."
|
||||||
|
|
||||||
# Verify only one normalized path entry exists
|
# Verify both calls to add_file were made
|
||||||
assert len(_global_memory["related_files"]) == 1
|
assert mock_related_files_repository.return_value.add_file.call_count == 2
|
||||||
assert os.path.abspath("file.txt") in _global_memory["related_files"].values()
|
|
||||||
finally:
|
finally:
|
||||||
# Restore original directory
|
# Restore original directory
|
||||||
os.chdir(original_dir)
|
os.chdir(original_dir)
|
||||||
|
|
||||||
|
|
||||||
@patch('ra_aid.agents.key_snippets_gc_agent.log_work_event')
|
@patch('ra_aid.tools.memory.is_binary_file')
|
||||||
def test_key_snippets_integration(mock_log_work_event, reset_memory, mock_key_snippet_repository):
|
def test_emit_related_files_binary_filtering(mock_is_binary, reset_memory, mock_related_files_repository, tmp_path):
|
||||||
"""Integration test for key snippets functionality"""
|
|
||||||
# Create test files
|
|
||||||
import tempfile
|
|
||||||
import os
|
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tmp_path:
|
|
||||||
file1 = os.path.join(tmp_path, "file1.py")
|
|
||||||
with open(file1, 'w') as f:
|
|
||||||
f.write("def func1():\n pass")
|
|
||||||
|
|
||||||
file2 = os.path.join(tmp_path, "file2.py")
|
|
||||||
with open(file2, 'w') as f:
|
|
||||||
f.write("def func2():\n return True")
|
|
||||||
|
|
||||||
file3 = os.path.join(tmp_path, "file3.py")
|
|
||||||
with open(file3, 'w') as f:
|
|
||||||
f.write("class TestClass:\n pass")
|
|
||||||
|
|
||||||
# Initial snippets to add
|
|
||||||
snippets = [
|
|
||||||
{
|
|
||||||
"filepath": file1,
|
|
||||||
"line_number": 10,
|
|
||||||
"snippet": "def func1():\n pass",
|
|
||||||
"description": "First function",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"filepath": file2,
|
|
||||||
"line_number": 20,
|
|
||||||
"snippet": "def func2():\n return True",
|
|
||||||
"description": "Second function",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"filepath": file3,
|
|
||||||
"line_number": 30,
|
|
||||||
"snippet": "class TestClass:\n pass",
|
|
||||||
"description": "Test class",
|
|
||||||
},
|
|
||||||
]
|
|
||||||
|
|
||||||
# Add all snippets one by one
|
|
||||||
for i, snippet in enumerate(snippets):
|
|
||||||
result = emit_key_snippet.invoke({"snippet_info": snippet})
|
|
||||||
assert result == f"Snippet #{i} stored."
|
|
||||||
|
|
||||||
# Reset mock to clear call history
|
|
||||||
mock_key_snippet_repository.reset_mock()
|
|
||||||
|
|
||||||
# Delete some but not all snippets (0 and 2)
|
|
||||||
with patch('ra_aid.agents.key_snippets_gc_agent.get_key_snippet_repository', mock_key_snippet_repository):
|
|
||||||
result = delete_key_snippets.invoke({"snippet_ids": [0, 2]})
|
|
||||||
assert result == "Snippets deleted."
|
|
||||||
|
|
||||||
# Reset mock again
|
|
||||||
mock_key_snippet_repository.reset_mock()
|
|
||||||
|
|
||||||
# Add new snippet
|
|
||||||
file4 = os.path.join(tmp_path, "file4.py")
|
|
||||||
with open(file4, 'w') as f:
|
|
||||||
f.write("def func4():\n return False")
|
|
||||||
|
|
||||||
new_snippet = {
|
|
||||||
"filepath": file4,
|
|
||||||
"line_number": 40,
|
|
||||||
"snippet": "def func4():\n return False",
|
|
||||||
"description": "Fourth function",
|
|
||||||
}
|
|
||||||
result = emit_key_snippet.invoke({"snippet_info": new_snippet})
|
|
||||||
assert result == "Snippet #3 stored."
|
|
||||||
|
|
||||||
# Verify create was called with correct params
|
|
||||||
mock_key_snippet_repository.return_value.create.assert_called_with(
|
|
||||||
filepath=file4,
|
|
||||||
line_number=40,
|
|
||||||
snippet="def func4():\n return False",
|
|
||||||
description="Fourth function",
|
|
||||||
human_input_id=ANY
|
|
||||||
)
|
|
||||||
|
|
||||||
# Reset mock again
|
|
||||||
mock_key_snippet_repository.reset_mock()
|
|
||||||
|
|
||||||
# Delete remaining snippets
|
|
||||||
with patch('ra_aid.agents.key_snippets_gc_agent.get_key_snippet_repository', mock_key_snippet_repository):
|
|
||||||
result = delete_key_snippets.invoke({"snippet_ids": [1, 3]})
|
|
||||||
assert result == "Snippets deleted."
|
|
||||||
|
|
||||||
|
|
||||||
def test_emit_related_files_binary_filtering(reset_memory, monkeypatch):
|
|
||||||
"""Test that binary files are filtered out when adding related files"""
|
"""Test that binary files are filtered out when adding related files"""
|
||||||
import tempfile
|
# Create test files
|
||||||
import os
|
text_file1 = tmp_path / "text1.txt"
|
||||||
|
text_file1.write_text("Text file 1 content")
|
||||||
|
text_file2 = tmp_path / "text2.txt"
|
||||||
|
text_file2.write_text("Text file 2 content")
|
||||||
|
binary_file1 = tmp_path / "binary1.bin"
|
||||||
|
binary_file1.write_text("Binary file 1 content")
|
||||||
|
binary_file2 = tmp_path / "binary2.bin"
|
||||||
|
binary_file2.write_text("Binary file 2 content")
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tmp_path:
|
# Mock is_binary_file to identify our "binary" files
|
||||||
# Create test text files
|
def mock_binary_check(filepath):
|
||||||
text_file1 = os.path.join(tmp_path, "text1.txt")
|
return ".bin" in str(filepath)
|
||||||
with open(text_file1, 'w') as f:
|
mock_is_binary.side_effect = mock_binary_check
|
||||||
f.write("Text file 1 content")
|
|
||||||
|
|
||||||
text_file2 = os.path.join(tmp_path, "text2.txt")
|
# Call emit_related_files with mix of text and binary files
|
||||||
with open(text_file2, 'w') as f:
|
result = emit_related_files.invoke({
|
||||||
f.write("Text file 2 content")
|
"files": [
|
||||||
|
str(text_file1),
|
||||||
|
str(binary_file1),
|
||||||
|
str(text_file2),
|
||||||
|
str(binary_file2),
|
||||||
|
]
|
||||||
|
})
|
||||||
|
|
||||||
# Create test "binary" files
|
# Verify the result message
|
||||||
binary_file1 = os.path.join(tmp_path, "binary1.bin")
|
assert "Files noted." in result
|
||||||
with open(binary_file1, 'w') as f:
|
assert "Binary files skipped:" in result
|
||||||
f.write("Binary file 1 content")
|
|
||||||
|
|
||||||
binary_file2 = os.path.join(tmp_path, "binary2.bin")
|
# Verify repository calls - should only call add_file for text files
|
||||||
with open(binary_file2, 'w') as f:
|
# Binary files should be filtered out before reaching the repository
|
||||||
f.write("Binary file 2 content")
|
assert mock_related_files_repository.return_value.add_file.call_count == 2
|
||||||
|
mock_related_files_repository.return_value.add_file.assert_any_call(str(text_file1))
|
||||||
# Mock the is_binary_file function to identify our "binary" files
|
mock_related_files_repository.return_value.add_file.assert_any_call(str(text_file2))
|
||||||
def mock_is_binary_file(filepath):
|
|
||||||
return ".bin" in str(filepath)
|
|
||||||
|
|
||||||
# Apply the mock
|
|
||||||
import ra_aid.tools.memory
|
|
||||||
monkeypatch.setattr(ra_aid.tools.memory, "is_binary_file", mock_is_binary_file)
|
|
||||||
|
|
||||||
# Call emit_related_files with mix of text and binary files
|
|
||||||
result = emit_related_files.invoke(
|
|
||||||
{
|
|
||||||
"files": [
|
|
||||||
text_file1,
|
|
||||||
binary_file1,
|
|
||||||
text_file2,
|
|
||||||
binary_file2,
|
|
||||||
]
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Verify the result message mentions skipped binary files
|
|
||||||
assert "Files noted." in result
|
|
||||||
assert "Binary files skipped:" in result
|
|
||||||
assert binary_file1 in result
|
|
||||||
assert binary_file2 in result
|
|
||||||
|
|
||||||
# Verify only text files were added to related_files
|
|
||||||
assert len(_global_memory["related_files"]) == 2
|
|
||||||
file_values = list(_global_memory["related_files"].values())
|
|
||||||
assert text_file1 in file_values
|
|
||||||
assert text_file2 in file_values
|
|
||||||
assert binary_file1 not in file_values
|
|
||||||
assert binary_file2 not in file_values
|
|
||||||
|
|
||||||
# Verify counter is correct (only incremented for text files)
|
|
||||||
assert _global_memory["related_file_id_counter"] == 2
|
|
||||||
|
|
||||||
|
|
||||||
def test_is_binary_file_with_ascii(reset_memory, monkeypatch):
|
def test_is_binary_file_with_ascii():
|
||||||
"""Test that ASCII files are correctly identified as text files"""
|
"""Test that ASCII files are correctly identified as text files"""
|
||||||
import os
|
|
||||||
import tempfile
|
import tempfile
|
||||||
import ra_aid.tools.memory
|
|
||||||
|
|
||||||
# Create a test ASCII file
|
# Create a test ASCII file
|
||||||
with tempfile.NamedTemporaryFile(mode='w', delete=False) as f:
|
with tempfile.NamedTemporaryFile(mode='w', delete=False) as f:
|
||||||
|
|
@ -772,32 +648,22 @@ def test_is_binary_file_with_ascii(reset_memory, monkeypatch):
|
||||||
ascii_file_path = f.name
|
ascii_file_path = f.name
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Test with magic library if available
|
# Test real implementation with ASCII file
|
||||||
if ra_aid.tools.memory.magic:
|
is_binary = is_binary_file(ascii_file_path)
|
||||||
# Test real implementation with ASCII file
|
assert not is_binary, "ASCII file should not be identified as binary"
|
||||||
is_binary = ra_aid.tools.memory.is_binary_file(ascii_file_path)
|
|
||||||
assert not is_binary, "ASCII file should not be identified as binary"
|
|
||||||
|
|
||||||
# Test fallback implementation
|
# Test fallback implementation
|
||||||
# Mock magic to be None to force fallback implementation
|
is_binary_fallback = _is_binary_fallback(ascii_file_path)
|
||||||
monkeypatch.setattr(ra_aid.tools.memory, "magic", None)
|
assert not is_binary_fallback, "ASCII file should not be identified as binary with fallback method"
|
||||||
|
|
||||||
# Test fallback with ASCII file
|
|
||||||
is_binary = ra_aid.tools.memory.is_binary_file(ascii_file_path)
|
|
||||||
assert (
|
|
||||||
not is_binary
|
|
||||||
), "ASCII file should not be identified as binary with fallback method"
|
|
||||||
finally:
|
finally:
|
||||||
# Clean up
|
# Clean up
|
||||||
if os.path.exists(ascii_file_path):
|
if os.path.exists(ascii_file_path):
|
||||||
os.unlink(ascii_file_path)
|
os.unlink(ascii_file_path)
|
||||||
|
|
||||||
|
|
||||||
def test_is_binary_file_with_null_bytes(reset_memory, monkeypatch):
|
def test_is_binary_file_with_null_bytes():
|
||||||
"""Test that files with null bytes are correctly identified as binary"""
|
"""Test that files with null bytes are correctly identified as binary"""
|
||||||
import os
|
|
||||||
import tempfile
|
import tempfile
|
||||||
import ra_aid.tools.memory
|
|
||||||
|
|
||||||
# Create a file with null bytes (binary content)
|
# Create a file with null bytes (binary content)
|
||||||
binary_file = tempfile.NamedTemporaryFile(delete=False)
|
binary_file = tempfile.NamedTemporaryFile(delete=False)
|
||||||
|
|
@ -805,21 +671,13 @@ def test_is_binary_file_with_null_bytes(reset_memory, monkeypatch):
|
||||||
binary_file.close()
|
binary_file.close()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Test with magic library if available
|
# Test real implementation with binary file
|
||||||
if ra_aid.tools.memory.magic:
|
is_binary = is_binary_file(binary_file.name)
|
||||||
# Test real implementation with binary file
|
assert is_binary, "File with null bytes should be identified as binary"
|
||||||
is_binary = ra_aid.tools.memory.is_binary_file(binary_file.name)
|
|
||||||
assert is_binary, "File with null bytes should be identified as binary"
|
|
||||||
|
|
||||||
# Test fallback implementation
|
# Test fallback implementation
|
||||||
# Mock magic to be None to force fallback implementation
|
is_binary_fallback = _is_binary_fallback(binary_file.name)
|
||||||
monkeypatch.setattr(ra_aid.tools.memory, "magic", None)
|
assert is_binary_fallback, "File with null bytes should be identified as binary with fallback method"
|
||||||
|
|
||||||
# Test fallback with binary file
|
|
||||||
is_binary = ra_aid.tools.memory.is_binary_file(binary_file.name)
|
|
||||||
assert (
|
|
||||||
is_binary
|
|
||||||
), "File with null bytes should be identified as binary with fallback method"
|
|
||||||
finally:
|
finally:
|
||||||
# Clean up
|
# Clean up
|
||||||
if os.path.exists(binary_file.name):
|
if os.path.exists(binary_file.name):
|
||||||
|
|
@ -850,7 +708,7 @@ def test_python_file_detection():
|
||||||
import magic
|
import magic
|
||||||
if magic:
|
if magic:
|
||||||
# Only run this part of the test if magic is available
|
# Only run this part of the test if magic is available
|
||||||
with patch('ra_aid.tools.memory.magic') as mock_magic:
|
with patch('ra_aid.utils.file_utils.magic') as mock_magic:
|
||||||
# Mock magic to simulate the behavior that causes the issue
|
# Mock magic to simulate the behavior that causes the issue
|
||||||
mock_magic.from_file.side_effect = [
|
mock_magic.from_file.side_effect = [
|
||||||
"text/x-python", # First call with mime=True
|
"text/x-python", # First call with mime=True
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue