diff --git a/ra_aid/agent_utils.py b/ra_aid/agent_utils.py index 35d5f07..e9e9f5b 100644 --- a/ra_aid/agent_utils.py +++ b/ra_aid/agent_utils.py @@ -85,6 +85,7 @@ from ra_aid.tool_configs import ( ) from ra_aid.tools.handle_user_defined_test_cmd_execution import execute_test_command from ra_aid.database.repositories.key_fact_repository import KeyFactRepository +from ra_aid.database.repositories.human_input_repository import HumanInputRepository from ra_aid.model_formatters import format_key_facts_dict from ra_aid.tools.memory import ( _global_memory, @@ -99,6 +100,7 @@ logger = get_logger(__name__) # Initialize key fact repository key_fact_repository = KeyFactRepository() +human_input_repository = HumanInputRepository() @tool @@ -390,6 +392,16 @@ def run_research_agent( code_snippets = _global_memory.get("code_snippets", "") related_files = _global_memory.get("related_files", "") + current_date = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + working_directory = os.getcwd() + + # Get the last human input, if it exists + base_task = base_task_or_query + recent_inputs = human_input_repository.get_recent(1) + if recent_inputs and len(recent_inputs) > 0: + last_human_input = recent_inputs[0].content + base_task = f"{last_human_input}\n{base_task}" + try: project_info = get_project_info(".", file_limit=2000) formatted_project_info = format_project_info(project_info) @@ -397,13 +409,10 @@ def run_research_agent( logger.warning(f"Failed to get project info: {e}") formatted_project_info = "" - current_date = datetime.now().strftime("%Y-%m-%d %H:%M:%S") - working_directory = os.getcwd() - prompt = (RESEARCH_ONLY_PROMPT if research_only else RESEARCH_PROMPT).format( current_date=current_date, working_directory=working_directory, - base_task=base_task_or_query, + base_task=base_task, research_only_note=( "" if research_only @@ -420,6 +429,8 @@ def run_research_agent( new_project_hints=NEW_PROJECT_HINTS if project_info.is_new else "", ) + print("HERE", prompt) + config = _global_memory.get("config", {}) if not config else config recursion_limit = config.get("recursion_limit", DEFAULT_RECURSION_LIMIT) run_config = {