include last human input in research prompt

This commit is contained in:
AI Christianson 2025-03-02 20:54:39 -05:00
parent 9a69bb173e
commit 80e8d9134b
1 changed files with 15 additions and 4 deletions

View File

@ -85,6 +85,7 @@ from ra_aid.tool_configs import (
) )
from ra_aid.tools.handle_user_defined_test_cmd_execution import execute_test_command from ra_aid.tools.handle_user_defined_test_cmd_execution import execute_test_command
from ra_aid.database.repositories.key_fact_repository import KeyFactRepository from ra_aid.database.repositories.key_fact_repository import KeyFactRepository
from ra_aid.database.repositories.human_input_repository import HumanInputRepository
from ra_aid.model_formatters import format_key_facts_dict from ra_aid.model_formatters import format_key_facts_dict
from ra_aid.tools.memory import ( from ra_aid.tools.memory import (
_global_memory, _global_memory,
@ -99,6 +100,7 @@ logger = get_logger(__name__)
# Initialize key fact repository # Initialize key fact repository
key_fact_repository = KeyFactRepository() key_fact_repository = KeyFactRepository()
human_input_repository = HumanInputRepository()
@tool @tool
@ -390,6 +392,16 @@ def run_research_agent(
code_snippets = _global_memory.get("code_snippets", "") code_snippets = _global_memory.get("code_snippets", "")
related_files = _global_memory.get("related_files", "") related_files = _global_memory.get("related_files", "")
current_date = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
working_directory = os.getcwd()
# Get the last human input, if it exists
base_task = base_task_or_query
recent_inputs = human_input_repository.get_recent(1)
if recent_inputs and len(recent_inputs) > 0:
last_human_input = recent_inputs[0].content
base_task = f"<last human input>{last_human_input}</last human input>\n{base_task}"
try: try:
project_info = get_project_info(".", file_limit=2000) project_info = get_project_info(".", file_limit=2000)
formatted_project_info = format_project_info(project_info) formatted_project_info = format_project_info(project_info)
@ -397,13 +409,10 @@ def run_research_agent(
logger.warning(f"Failed to get project info: {e}") logger.warning(f"Failed to get project info: {e}")
formatted_project_info = "" formatted_project_info = ""
current_date = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
working_directory = os.getcwd()
prompt = (RESEARCH_ONLY_PROMPT if research_only else RESEARCH_PROMPT).format( prompt = (RESEARCH_ONLY_PROMPT if research_only else RESEARCH_PROMPT).format(
current_date=current_date, current_date=current_date,
working_directory=working_directory, working_directory=working_directory,
base_task=base_task_or_query, base_task=base_task,
research_only_note=( research_only_note=(
"" ""
if research_only if research_only
@ -420,6 +429,8 @@ def run_research_agent(
new_project_hints=NEW_PROJECT_HINTS if project_info.is_new else "", new_project_hints=NEW_PROJECT_HINTS if project_info.is_new else "",
) )
print("HERE", prompt)
config = _global_memory.get("config", {}) if not config else config config = _global_memory.get("config", {}) if not config else config
recursion_limit = config.get("recursion_limit", DEFAULT_RECURSION_LIMIT) recursion_limit = config.get("recursion_limit", DEFAULT_RECURSION_LIMIT)
run_config = { run_config = {