From f80545d693ae324ca35d1cecdd11e2e93d602cea Mon Sep 17 00:00:00 2001 From: user Date: Sun, 22 Dec 2024 14:00:17 -0500 Subject: [PATCH] Allow graceful interruption of agents. --- CHANGELOG.md | 1 + ra_aid/__main__.py | 3 +- ra_aid/agent_utils.py | 91 ++++++++++++++++++++++-------------- ra_aid/console/formatting.py | 1 + ra_aid/prompts.py | 3 +- ra_aid/tools/agent.py | 12 +++-- 6 files changed, 68 insertions(+), 43 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 56411d0..5491a0e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] - Chat mode. +- Allow agents to be interrupted. ## [0.7.1] - 2024-12-20 diff --git a/ra_aid/__main__.py b/ra_aid/__main__.py index 20fbd5e..96170b4 100644 --- a/ra_aid/__main__.py +++ b/ra_aid/__main__.py @@ -3,6 +3,7 @@ import sys import uuid from rich.panel import Panel from rich.console import Console +from ra_aid.console.formatting import print_interrupt from langgraph.checkpoint.memory import MemorySaver from langgraph.prebuilt import create_react_agent from ra_aid.env import validate_environment @@ -221,7 +222,7 @@ def main(): ) except KeyboardInterrupt: - console.print("\n[red]Operation cancelled by user[/red]") + print_interrupt("Operation cancelled by user") sys.exit(1) if __name__ == "__main__": diff --git a/ra_aid/agent_utils.py b/ra_aid/agent_utils.py index d74b0ed..f609454 100644 --- a/ra_aid/agent_utils.py +++ b/ra_aid/agent_utils.py @@ -4,8 +4,13 @@ import time import uuid from typing import Optional, Any, List +import signal +import threading +import time +from typing import Optional + from langgraph.prebuilt import create_react_agent -from ra_aid.console.formatting import print_stage_header +from ra_aid.console.formatting import print_stage_header, print_error, print_interrupt from ra_aid.console.output import print_agent_output from ra_aid.tool_configs import ( get_implementation_tools, @@ -135,10 +140,6 @@ def run_research_agent( # Run agent with retry logic return run_agent_with_retry(agent, prompt, run_config) -def print_error(msg: str) -> None: - """Print error messages.""" - console.print(f"\n{msg}", style="red") - def run_planning_agent( base_task: str, model, @@ -272,37 +273,55 @@ def run_task_implementation_agent( # Run agent with retry logic return run_agent_with_retry(agent, prompt, run_config) +_CONTEXT_STACK = [] +_INTERRUPT_CONTEXT = None + +def _request_interrupt(signum, frame): + global _INTERRUPT_CONTEXT + if _CONTEXT_STACK: + _INTERRUPT_CONTEXT = _CONTEXT_STACK[-1] + +class InterruptibleSection: + def __enter__(self): + _CONTEXT_STACK.append(self) + return self + + def __exit__(self, exc_type, exc_value, traceback): + _CONTEXT_STACK.remove(self) + +def check_interrupt(): + if _CONTEXT_STACK and _INTERRUPT_CONTEXT is _CONTEXT_STACK[-1]: + raise KeyboardInterrupt("Interrupt requested") + def run_agent_with_retry(agent, prompt: str, config: dict) -> Optional[str]: - """Run an agent with retry logic for internal server errors and task completion handling. - - Args: - agent: The agent to run - prompt: The prompt to send to the agent - config: Configuration dictionary for the agent - - Returns: - Optional[str]: The completion message if task was completed, None otherwise - - Handles API errors with exponential backoff retry logic and checks for task - completion after each chunk of output. - """ + original_handler = None + if threading.current_thread() is threading.main_thread(): + original_handler = signal.getsignal(signal.SIGINT) + signal.signal(signal.SIGINT, _request_interrupt) + max_retries = 20 - base_delay = 1 # Initial delay in seconds - - for attempt in range(max_retries): + base_delay = 1 + + with InterruptibleSection(): try: - for chunk in agent.stream( - {"messages": [HumanMessage(content=prompt)]}, - config - ): - print_agent_output(chunk) - break - except (InternalServerError, APITimeoutError, RateLimitError, APIError) as e: - if attempt == max_retries - 1: - raise RuntimeError(f"Max retries ({max_retries}) exceeded. Last error: {str(e)}") - - delay = base_delay * (2 ** attempt) # Exponential backoff - error_type = e.__class__.__name__ - print_error(f"Encountered {error_type}: {str(e)}. Retrying in {delay} seconds... (Attempt {attempt + 1}/{max_retries})") - time.sleep(delay) - continue + for attempt in range(max_retries): + check_interrupt() + try: + for chunk in agent.stream({"messages": [HumanMessage(content=prompt)]}, config): + check_interrupt() + print_agent_output(chunk) + return "Agent run completed successfully" + except KeyboardInterrupt: + raise + except (InternalServerError, APITimeoutError, RateLimitError, APIError) as e: + if attempt == max_retries - 1: + raise RuntimeError(f"Max retries ({max_retries}) exceeded. Last error: {e}") + delay = base_delay * (2 ** attempt) + print_error(f"Encountered {e.__class__.__name__}: {e}. Retrying in {delay}s... (Attempt {attempt+1}/{max_retries})") + start = time.monotonic() + while time.monotonic() - start < delay: + check_interrupt() + time.sleep(0.1) + finally: + if original_handler and threading.current_thread() is threading.main_thread(): + signal.signal(signal.SIGINT, original_handler) diff --git a/ra_aid/console/formatting.py b/ra_aid/console/formatting.py index 444c2b8..9728f69 100644 --- a/ra_aid/console/formatting.py +++ b/ra_aid/console/formatting.py @@ -55,4 +55,5 @@ def print_interrupt(message: str) -> None: Args: message: The interruption message to display (supports Markdown formatting) """ + print() # Give space for "^C" console.print(Panel(Markdown(message), title="Interrupted", border_style="yellow bold")) diff --git a/ra_aid/prompts.py b/ra_aid/prompts.py index 601bc4f..baff49a 100644 --- a/ra_aid/prompts.py +++ b/ra_aid/prompts.py @@ -361,8 +361,9 @@ Exit Criteria: Remember: - Always begin by calling ask_human. - Always ask_human before finalizing or exiting. - - Never announce that you are going to ask the human, just do it. + - Never announce that you are going to use a tool, just quietly use it. - Do communicate results/responses from tools that you call as it pertains to the users request. + - If the user interrupts/cancels an operation, you may want to ask why. - For deep debugging, logic analysis, or correctness checks, rely on the expert (if expert is available) for guidance. """ \ No newline at end of file diff --git a/ra_aid/tools/agent.py b/ra_aid/tools/agent.py index 6e395c2..072ae41 100644 --- a/ra_aid/tools/agent.py +++ b/ra_aid/tools/agent.py @@ -12,6 +12,8 @@ from .memory import get_memory_value, get_related_files from ..llm import initialize_llm from ..console import print_task_header +CANCELLED_BY_USER_REASON = "The operation was explicitly cancelled by the user. This typically is an indication that the action requested was not aligned with the user request." + RESEARCH_AGENT_RECURSION_LIMIT = 2 console = Console() @@ -64,7 +66,7 @@ def request_research(query: str) -> ResearchResult: except KeyboardInterrupt: print_interrupt("Research interrupted by user") success = False - reason = "cancelled_by_user" + reason = CANCELLED_BY_USER_REASON except Exception as e: print_error(f"Error during research: {str(e)}") success = False @@ -116,9 +118,9 @@ def request_research_and_implementation(query: str) -> Dict[str, Any]: success = True reason = None except KeyboardInterrupt: - console.print("\n[yellow]Research interrupted by user[/yellow]") + print_interrupt("Research interrupted by user") success = False - reason = "cancelled_by_user" + reason = CANCELLED_BY_USER_REASON except Exception as e: console.print(f"\n[red]Error during research: {str(e)}[/red]") success = False @@ -177,7 +179,7 @@ def request_task_implementation(task_spec: str) -> Dict[str, Any]: except KeyboardInterrupt: print_interrupt("Task implementation interrupted by user") success = False - reason = "cancelled_by_user" + reason = CANCELLED_BY_USER_REASON except Exception as e: print_error(f"Error during task implementation: {str(e)}") success = False @@ -226,7 +228,7 @@ def request_implementation(task_spec: str) -> Dict[str, Any]: except KeyboardInterrupt: print_interrupt("Planning interrupted by user") success = False - reason = "cancelled_by_user" + reason = CANCELLED_BY_USER_REASON except Exception as e: print_error(f"Error during planning: {str(e)}") success = False