RA.Aid/ra_aid/tools/memory.py

642 lines
22 KiB
Python

import os
from typing import Any, Dict, List, Optional
from langchain_core.tools import tool
from ra_aid.utils.file_utils import is_binary_file
from rich.console import Console
from rich.markdown import Markdown
from rich.panel import Panel
from typing_extensions import TypedDict
from ra_aid.agent_context import (
mark_plan_completed,
mark_should_exit,
mark_task_completed,
)
from ra_aid.database.repositories.key_fact_repository import get_key_fact_repository
from ra_aid.database.repositories.key_snippet_repository import get_key_snippet_repository
from ra_aid.database.repositories.human_input_repository import get_human_input_repository
from ra_aid.database.repositories.research_note_repository import get_research_note_repository
from ra_aid.database.repositories.trajectory_repository import get_trajectory_repository
from ra_aid.database.repositories.work_log_repository import get_work_log_repository
from ra_aid.model_formatters import key_snippets_formatter
from ra_aid.logging_config import get_logger
logger = get_logger(__name__)
class SnippetInfo(TypedDict):
filepath: str
line_number: int
snippet: str
description: Optional[str]
console = Console()
# Import repositories using the get_* functions
from ra_aid.database.repositories.key_fact_repository import get_key_fact_repository
# Import the related files repository
from ra_aid.database.repositories.related_files_repository import get_related_files_repository
@tool("emit_research_notes")
def emit_research_notes(notes: str) -> str:
"""Use this when you have completed your research to share your notes in markdown format.
Keep your research notes information dense and no more than 300 words.
Args:
notes: REQUIRED The research notes to store
"""
# Try to get the latest human input
human_input_id = None
try:
human_input_repo = get_human_input_repository()
human_input_id = human_input_repo.get_most_recent_id()
except RuntimeError as e:
logger.warning(f"No HumanInputRepository available: {str(e)}")
except Exception as e:
logger.warning(f"Failed to get recent human input: {str(e)}")
try:
# Create note in database using repository
created_note = get_research_note_repository().create(notes, human_input_id=human_input_id)
note_id = created_note.id
# Format the note using the formatter
from ra_aid.model_formatters.research_notes_formatter import format_research_note
formatted_note = format_research_note(note_id, notes)
# Record to trajectory before displaying panel
try:
trajectory_repo = get_trajectory_repository()
trajectory_repo.create(
tool_name="emit_research_notes",
tool_parameters={"notes": notes},
step_data={
"note_id": note_id,
"display_title": "Research Notes",
},
record_type="memory_operation",
human_input_id=human_input_id
)
except RuntimeError as e:
logger.warning(f"Failed to record trajectory: {str(e)}")
# Display formatted note
console.print(Panel(Markdown(formatted_note), title="🔍 Research Notes"))
log_work_event(f"Stored research note #{note_id}.")
# Check if we need to clean up notes (more than 30)
try:
all_notes = get_research_note_repository().get_all()
if len(all_notes) > 30:
# Trigger the research notes cleaner agent
try:
from ra_aid.agents.research_notes_gc_agent import run_research_notes_gc_agent
run_research_notes_gc_agent()
except Exception as e:
logger.error(f"Failed to run research notes cleaner: {str(e)}")
except RuntimeError as e:
logger.error(f"Failed to access research note repository: {str(e)}")
return f"Research note #{note_id} stored."
except RuntimeError as e:
logger.error(f"Failed to access research note repository: {str(e)}")
console.print(f"Error storing research note: {str(e)}", style="red")
return "Failed to store research note."
@tool("emit_key_facts")
def emit_key_facts(facts: List[str]) -> str:
"""Store multiple key facts about the project or current task in global memory.
Args:
facts: List of key facts to store
"""
results = []
# Try to get the latest human input
human_input_id = None
try:
human_input_repo = get_human_input_repository()
human_input_id = human_input_repo.get_most_recent_id()
except RuntimeError as e:
logger.warning(f"No HumanInputRepository available: {str(e)}")
except Exception as e:
logger.warning(f"Failed to get recent human input: {str(e)}")
for fact in facts:
try:
# Create fact in database using repository
created_fact = get_key_fact_repository().create(fact, human_input_id=human_input_id)
fact_id = created_fact.id
except RuntimeError as e:
logger.error(f"Failed to access key fact repository: {str(e)}")
console.print(f"Error storing fact: {str(e)}", style="red")
continue
# Record to trajectory before displaying panel
try:
trajectory_repo = get_trajectory_repository()
trajectory_repo.create(
tool_name="emit_key_facts",
tool_parameters={"facts": [fact]},
step_data={
"fact_id": fact_id,
"fact": fact,
"display_title": f"Key Fact #{fact_id}",
},
record_type="memory_operation",
human_input_id=human_input_id
)
except RuntimeError as e:
logger.warning(f"Failed to record trajectory: {str(e)}")
# Display panel with ID
console.print(
Panel(
Markdown(fact),
title=f"💡 Key Fact #{fact_id}",
border_style="bright_cyan",
)
)
# Add result message
results.append(f"Stored fact #{fact_id}: {fact}")
log_work_event(f"Stored {len(facts)} key facts.")
# Check if we need to clean up facts (more than 30)
try:
all_facts = get_key_fact_repository().get_all()
if len(all_facts) > 50:
# Trigger the key facts cleaner agent
try:
from ra_aid.agents.key_facts_gc_agent import run_key_facts_gc_agent
run_key_facts_gc_agent()
except Exception as e:
logger.error(f"Failed to run key facts cleaner: {str(e)}")
except RuntimeError as e:
logger.error(f"Failed to access key fact repository: {str(e)}")
return "Facts stored."
@tool("emit_key_snippet")
def emit_key_snippet(snippet_info: SnippetInfo) -> str:
"""Store a single source code snippet in the database which represents key information.
Automatically adds the filepath of the snippet to related files.
This is for **existing**, or **just-written** files, not for things to be created in the future.
ONLY emit snippets if they will be relevant to UPCOMING work.
Focus on external interfaces and things that are very specific and relevant to UPCOMING work.
SNIPPETS SHOULD TYPICALLY BE MULTIPLE LINES, NOT SINGLE LINES, NOT ENTIRE FILES.
Args:
snippet_info: Dict with keys:
- filepath: Path to the source file
- line_number: Line number where the snippet starts
- snippet: The source code snippet text
- description: Optional description of the significance
"""
# Add filepath to related files
emit_related_files.invoke({"files": [snippet_info["filepath"]]})
# Try to get the latest human input
human_input_id = None
try:
human_input_repo = get_human_input_repository()
human_input_id = human_input_repo.get_most_recent_id()
except RuntimeError as e:
logger.warning(f"No HumanInputRepository available: {str(e)}")
except Exception as e:
logger.warning(f"Failed to get recent human input: {str(e)}")
# Create a new key snippet in the database
key_snippet = get_key_snippet_repository().create(
filepath=snippet_info["filepath"],
line_number=snippet_info["line_number"],
snippet=snippet_info["snippet"],
description=snippet_info["description"],
human_input_id=human_input_id,
)
# Get the snippet ID from the database record
snippet_id = key_snippet.id
# Format display text as markdown
display_text = [
"**Source Location**:",
f"- File: `{snippet_info['filepath']}`",
f"- Line: `{snippet_info['line_number']}`",
"", # Empty line before code block
"**Code**:",
"```python",
snippet_info["snippet"].rstrip(), # Remove trailing whitespace
"```",
]
if snippet_info["description"]:
display_text.extend(["", "**Description**:", snippet_info["description"]])
# Record to trajectory before displaying panel
try:
trajectory_repo = get_trajectory_repository()
trajectory_repo.create(
tool_name="emit_key_snippet",
tool_parameters={
"snippet_info": {
"filepath": snippet_info["filepath"],
"line_number": snippet_info["line_number"],
"description": snippet_info["description"],
# Omit the full snippet content to avoid duplicating large text in the database
"snippet_length": len(snippet_info["snippet"])
}
},
step_data={
"snippet_id": snippet_id,
"filepath": snippet_info["filepath"],
"line_number": snippet_info["line_number"],
"display_title": f"Key Snippet #{snippet_id}",
},
record_type="memory_operation",
human_input_id=human_input_id
)
except RuntimeError as e:
logger.warning(f"Failed to record trajectory: {str(e)}")
# Display panel
console.print(
Panel(
Markdown("\n".join(display_text)),
title=f"📝 Key Snippet #{snippet_id}",
border_style="bright_cyan",
)
)
log_work_event(f"Stored code snippet #{snippet_id}.")
# Check if we need to clean up snippets (more than 20)
all_snippets = get_key_snippet_repository().get_all()
if len(all_snippets) > 35:
# Trigger the key snippets cleaner agent
try:
from ra_aid.agents.key_snippets_gc_agent import run_key_snippets_gc_agent
run_key_snippets_gc_agent()
except Exception as e:
logger.error(f"Failed to run key snippets cleaner: {str(e)}")
return f"Snippet #{snippet_id} stored."
@tool("one_shot_completed")
def one_shot_completed(message: str) -> str:
"""Signal that a one-shot task has been completed and execution should stop.
Only call this if you have already **fully** completed the original request.
Args:
message: Completion message to display
"""
mark_task_completed(message)
# Record to trajectory before displaying panel
human_input_id = None
try:
human_input_id = get_human_input_repository().get_most_recent_id()
trajectory_repo = get_trajectory_repository()
trajectory_repo.create(
tool_name="one_shot_completed",
tool_parameters={"message": message},
step_data={
"completion_message": message,
"display_title": "Task Completed",
},
record_type="task_completion",
human_input_id=human_input_id
)
except RuntimeError as e:
logger.warning(f"Failed to record trajectory: {str(e)}")
console.print(Panel(Markdown(message), title="✅ Task Completed"))
log_work_event(f"Task completed:\n\n{message}")
return "Completion noted."
@tool("task_completed")
def task_completed(message: str) -> str:
"""Mark the current task as completed with a completion message.
Args:
message: Message explaining how/why the task is complete
"""
mark_task_completed(message)
# Record to trajectory before displaying panel
human_input_id = None
try:
human_input_id = get_human_input_repository().get_most_recent_id()
trajectory_repo = get_trajectory_repository()
trajectory_repo.create(
tool_name="task_completed",
tool_parameters={"message": message},
step_data={
"completion_message": message,
"display_title": "Task Completed",
},
record_type="task_completion",
human_input_id=human_input_id
)
except RuntimeError as e:
logger.warning(f"Failed to record trajectory: {str(e)}")
console.print(Panel(Markdown(message), title="✅ Task Completed"))
log_work_event(f"Task completed:\n\n{message}")
return "Completion noted."
@tool("plan_implementation_completed")
def plan_implementation_completed(message: str) -> str:
"""Mark the entire implementation plan as completed.
Args:
message: Message explaining how the implementation plan was completed
"""
mark_should_exit(propagation_depth=1)
mark_plan_completed(message)
# Record to trajectory before displaying panel
human_input_id = None
try:
human_input_id = get_human_input_repository().get_most_recent_id()
trajectory_repo = get_trajectory_repository()
trajectory_repo.create(
tool_name="plan_implementation_completed",
tool_parameters={"message": message},
step_data={
"completion_message": message,
"display_title": "Plan Executed",
},
record_type="plan_completion",
human_input_id=human_input_id
)
except RuntimeError as e:
logger.warning(f"Failed to record trajectory: {str(e)}")
console.print(Panel(Markdown(message), title="✅ Plan Executed"))
log_work_event(f"Completed implementation:\n\n{message}")
return "Plan completion noted."
def get_related_files() -> List[str]:
"""Get the current list of related files.
Returns:
List of formatted strings in the format 'ID#X path/to/file.py'
"""
repo = get_related_files_repository()
return repo.format_related_files()
@tool("emit_related_files")
def emit_related_files(files: List[str]) -> str:
"""Store multiple related files that tools should work with.
Args:
files: List of file paths to add
"""
repo = get_related_files_repository()
# Store the repository's ID counter value before adding any files
try:
initial_next_id = repo.get_next_id()
except (AttributeError, TypeError):
# Handle case where repo is mocked in tests
initial_next_id = 0 # Use a safe default for mocked environments
results = []
added_files = []
invalid_paths = []
binary_files = []
# Process files
for file in files:
# First check if path exists
if not os.path.exists(file):
invalid_paths.append(file)
results.append(f"Error: Path '{file}' does not exist")
continue
# Then check if it's a directory
if os.path.isdir(file):
invalid_paths.append(file)
results.append(f"Error: Path '{file}' is a directory, not a file")
continue
# Finally validate it's a regular file
if not os.path.isfile(file):
invalid_paths.append(file)
results.append(f"Error: Path '{file}' exists but is not a regular file")
continue
# Check if it's a binary file
if is_binary_file(file):
binary_files.append(file)
results.append(f"Skipped binary file: '{file}'")
continue
# Add file to repository
file_id = repo.add_file(file)
if file_id is not None:
# Check if it's a truly new file (ID >= initial_next_id)
try:
is_truly_new = file_id >= initial_next_id
except TypeError:
# Handle case where file_id or initial_next_id is mocked in tests
is_truly_new = True # Default to True in test environments
# Also check for duplicates within this function call
is_duplicate_in_call = False
for r in results:
if r.startswith(f"File ID #{file_id}:"):
is_duplicate_in_call = True
break
# Only add to added_files if it's truly new AND not a duplicate in this call
if is_truly_new and not is_duplicate_in_call:
added_files.append((file_id, file)) # Keep original path for display
results.append(f"File ID #{file_id}: {file}")
# Record to trajectory before displaying panel for added files
if added_files:
files_added_md = "\n".join(f"- `{file}`" for id, file in added_files)
md_content = f"**Files Noted:**\n{files_added_md}"
human_input_id = None
try:
human_input_id = get_human_input_repository().get_most_recent_id()
trajectory_repo = get_trajectory_repository()
trajectory_repo.create(
tool_name="emit_related_files",
tool_parameters={"files": files},
step_data={
"added_files": [file for _, file in added_files],
"added_file_ids": [file_id for file_id, _ in added_files],
"display_title": "Related Files Noted",
},
record_type="memory_operation",
human_input_id=human_input_id
)
except RuntimeError as e:
logger.warning(f"Failed to record trajectory: {str(e)}")
console.print(
Panel(
Markdown(md_content),
title="📁 Related Files Noted",
border_style="green",
)
)
# Record to trajectory before displaying panel for binary files
if binary_files:
binary_files_md = "\n".join(f"- `{file}`" for file in binary_files)
md_content = f"**Binary Files Skipped:**\n{binary_files_md}"
human_input_id = None
try:
human_input_id = get_human_input_repository().get_most_recent_id()
trajectory_repo = get_trajectory_repository()
trajectory_repo.create(
tool_name="emit_related_files",
tool_parameters={"files": files},
step_data={
"binary_files": binary_files,
"display_title": "Binary Files Not Added",
},
record_type="memory_operation",
human_input_id=human_input_id
)
except RuntimeError as e:
logger.warning(f"Failed to record trajectory: {str(e)}")
console.print(
Panel(
Markdown(md_content),
title="⚠️ Binary Files Not Added",
border_style="yellow",
)
)
# Return summary message
if binary_files:
binary_files_list = ", ".join(f"'{file}'" for file in binary_files)
return f"Files noted. Binary files skipped: {binary_files_list}"
else:
return "Files noted."
def log_work_event(event: str) -> str:
"""Add timestamped entry to work log.
Internal function used to track major events during agent execution.
Each entry is stored with an ISO format timestamp.
Args:
event: Description of the event to log
Returns:
Confirmation message
Note:
Entries can be retrieved with get_work_log() as markdown formatted text.
"""
try:
repo = get_work_log_repository()
repo.add_entry(event)
return f"Event logged: {event}"
except RuntimeError as e:
logger.error(f"Failed to access work log repository: {str(e)}")
return f"Failed to log event: {str(e)}"
def get_work_log() -> str:
"""Return formatted markdown of work log entries.
Returns:
Markdown formatted text with timestamps as headings and events as content,
or 'No work log entries' if the log is empty.
Example:
## 2024-12-23T11:39:10
Task #1 added: Create login form
"""
try:
repo = get_work_log_repository()
return repo.format_work_log()
except RuntimeError as e:
logger.error(f"Failed to access work log repository: {str(e)}")
return "No work log entries"
def reset_work_log() -> str:
"""Clear the work log.
Returns:
Confirmation message
Note:
This permanently removes all work log entries. The operation cannot be undone.
"""
try:
repo = get_work_log_repository()
repo.clear()
return "Work log cleared"
except RuntimeError as e:
logger.error(f"Failed to access work log repository: {str(e)}")
return f"Failed to clear work log: {str(e)}"
@tool("deregister_related_files")
def deregister_related_files(file_ids: List[int]) -> str:
"""Delete multiple related files by their IDs.
Silently skips any IDs that don't exist.
Args:
file_ids: List of file IDs to delete
"""
repo = get_related_files_repository()
results = []
for file_id in file_ids:
deleted_file = repo.remove_file(file_id)
if deleted_file:
success_msg = (
f"Successfully removed related file #{file_id}: {deleted_file}"
)
console.print(
Panel(
Markdown(success_msg),
title="File Reference Removed",
border_style="green",
)
)
results.append(success_msg)
return "Files noted."