From f7523e86d8d9a33023b700801bb24837f027cbef Mon Sep 17 00:00:00 2001 From: user Date: Mon, 23 Dec 2024 14:05:59 -0500 Subject: [PATCH] Improve agent interruption UX by allowing user to specify feedback or exit the program entirely. --- ra_aid/__init__.py | 3 ++- ra_aid/__main__.py | 7 +++--- ra_aid/agent_utils.py | 2 +- ra_aid/console/__init__.py | 4 +-- ra_aid/console/formatting.py | 9 ++++--- ra_aid/tools/agent.py | 47 ++++++++++++++++++++++++++---------- 6 files changed, 48 insertions(+), 24 deletions(-) diff --git a/ra_aid/__init__.py b/ra_aid/__init__.py index 06aff82..d9f3fd2 100644 --- a/ra_aid/__init__.py +++ b/ra_aid/__init__.py @@ -1,5 +1,5 @@ from .__version__ import __version__ -from .console.formatting import print_stage_header, print_task_header, print_error +from .console.formatting import print_stage_header, print_task_header, print_error, print_interrupt from .console.output import print_agent_output from .text.processing import truncate_output from .agent_utils import run_agent_with_retry @@ -10,6 +10,7 @@ __all__ = [ 'print_agent_output', 'truncate_output', 'print_error', + 'print_interrupt', 'run_agent_with_retry', '__version__' ] diff --git a/ra_aid/__main__.py b/ra_aid/__main__.py index 8b69b4f..4af97e0 100644 --- a/ra_aid/__main__.py +++ b/ra_aid/__main__.py @@ -3,13 +3,14 @@ import sys import uuid from rich.panel import Panel from rich.console import Console -from ra_aid.console.formatting import print_interrupt from langgraph.checkpoint.memory import MemorySaver from langgraph.prebuilt import create_react_agent from ra_aid.env import validate_environment from ra_aid.tools.memory import _global_memory, get_related_files, get_memory_value from ra_aid.tools.human import ask_human -from ra_aid import print_stage_header, print_error +from ra_aid import print_stage_header, print_error, print_interrupt +from ra_aid.tools.agent import CANCELLED_BY_USER_REASON +from ra_aid.tools.human import ask_human from ra_aid.agent_utils import ( run_agent_with_retry, run_research_agent, @@ -227,7 +228,7 @@ def main(): ) except KeyboardInterrupt: - print_interrupt("Operation cancelled by user") + print_interrupt(f"Operation cancelled: {CANCELLED_BY_USER_REASON}") sys.exit(1) if __name__ == "__main__": diff --git a/ra_aid/agent_utils.py b/ra_aid/agent_utils.py index 9ac368a..ff1b9d3 100644 --- a/ra_aid/agent_utils.py +++ b/ra_aid/agent_utils.py @@ -10,7 +10,7 @@ import time from typing import Optional from langgraph.prebuilt import create_react_agent -from ra_aid.console.formatting import print_stage_header, print_error, print_interrupt +from ra_aid.console.formatting import print_stage_header, print_error from ra_aid.console.output import print_agent_output from ra_aid.tool_configs import ( get_implementation_tools, diff --git a/ra_aid/console/__init__.py b/ra_aid/console/__init__.py index 757483d..3522f88 100644 --- a/ra_aid/console/__init__.py +++ b/ra_aid/console/__init__.py @@ -1,4 +1,4 @@ -from .formatting import print_stage_header, print_task_header, print_error, console +from .formatting import print_stage_header, print_task_header, print_error, print_interrupt, console from .output import print_agent_output -__all__ = ['print_stage_header', 'print_task_header', 'print_agent_output', 'console', 'print_error'] +__all__ = ['print_stage_header', 'print_task_header', 'print_agent_output', 'console', 'print_error', 'print_interrupt'] diff --git a/ra_aid/console/formatting.py b/ra_aid/console/formatting.py index 9728f69..ef77904 100644 --- a/ra_aid/console/formatting.py +++ b/ra_aid/console/formatting.py @@ -50,10 +50,11 @@ def print_error(message: str) -> None: console.print(Panel(Markdown(message), title="Error", border_style="red bold")) def print_interrupt(message: str) -> None: - """Print an interruption message in a yellow-bordered panel with appropriate emoji. + """Print an interrupt message in a yellow-bordered panel with stop emoji. Args: - message: The interruption message to display (supports Markdown formatting) + message: The interrupt message to display (supports Markdown formatting) """ - print() # Give space for "^C" - console.print(Panel(Markdown(message), title="Interrupted", border_style="yellow bold")) + print() # Add spacing for ^C + console.print(Panel(Markdown(message), title="⛔ Interrupt", border_style="yellow bold")) + diff --git a/ra_aid/tools/agent.py b/ra_aid/tools/agent.py index 31f244b..85cf604 100644 --- a/ra_aid/tools/agent.py +++ b/ra_aid/tools/agent.py @@ -7,8 +7,9 @@ from typing_extensions import TypeAlias ResearchResult = Dict[str, Union[str, bool, Dict[int, Any], List[Any], None]] from rich.console import Console from ra_aid.tools.memory import _global_memory -from ra_aid.console.formatting import print_error, print_interrupt +from ra_aid.console.formatting import print_error from .memory import get_memory_value, get_related_files, get_work_log, reset_work_log +from .human import ask_human from ..llm import initialize_llm from ..console import print_task_header @@ -61,9 +62,14 @@ def request_research(query: str) -> ResearchResult: console_message=query ) except KeyboardInterrupt: - print_interrupt("Research interrupted by user") - success = False - reason = CANCELLED_BY_USER_REASON + try: + print() + response = ask_human.invoke({"question": "Why did you interrupt me?"}) + success = False + reason = response if response.strip() else CANCELLED_BY_USER_REASON + except Exception: + success = False + reason = CANCELLED_BY_USER_REASON except Exception as e: print_error(f"Error during research: {str(e)}") success = False @@ -118,9 +124,14 @@ def request_research_and_implementation(query: str) -> Dict[str, Any]: success = True reason = None except KeyboardInterrupt: - print_interrupt("Task interrupted by user") - success = False - reason = CANCELLED_BY_USER_REASON + try: + print() + response = ask_human.invoke({"question": "Why did you interrupt me?"}) + success = False + reason = response if response.strip() else CANCELLED_BY_USER_REASON + except Exception: + success = False + reason = CANCELLED_BY_USER_REASON except Exception as e: console.print(f"\n[red]Error during research: {str(e)}[/red]") success = False @@ -184,9 +195,14 @@ def request_task_implementation(task_spec: str) -> Dict[str, Any]: success = True reason = None except KeyboardInterrupt: - print_interrupt("Task implementation interrupted by user") - success = False - reason = CANCELLED_BY_USER_REASON + try: + print() + response = ask_human.invoke({"question": "Why did you interrupt me?"}) + success = False + reason = response if response.strip() else CANCELLED_BY_USER_REASON + except Exception: + success = False + reason = CANCELLED_BY_USER_REASON except Exception as e: print_error(f"Error during task implementation: {str(e)}") success = False @@ -240,9 +256,14 @@ def request_implementation(task_spec: str) -> Dict[str, Any]: success = True reason = None except KeyboardInterrupt: - print_interrupt("Planning interrupted by user") - success = False - reason = CANCELLED_BY_USER_REASON + try: + print() + response = ask_human.invoke({"question": "Why did you interrupt me?"}) + success = False + reason = response if response.strip() else CANCELLED_BY_USER_REASON + except Exception: + success = False + reason = CANCELLED_BY_USER_REASON except Exception as e: print_error(f"Error during planning: {str(e)}") success = False