Allow graceful interruption of agents.

This commit is contained in:
user 2024-12-22 14:00:17 -05:00
parent 6d51a9e1ef
commit f80545d693
6 changed files with 68 additions and 43 deletions

View File

@ -8,6 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [Unreleased] ## [Unreleased]
- Chat mode. - Chat mode.
- Allow agents to be interrupted.
## [0.7.1] - 2024-12-20 ## [0.7.1] - 2024-12-20

View File

@ -3,6 +3,7 @@ import sys
import uuid import uuid
from rich.panel import Panel from rich.panel import Panel
from rich.console import Console from rich.console import Console
from ra_aid.console.formatting import print_interrupt
from langgraph.checkpoint.memory import MemorySaver from langgraph.checkpoint.memory import MemorySaver
from langgraph.prebuilt import create_react_agent from langgraph.prebuilt import create_react_agent
from ra_aid.env import validate_environment from ra_aid.env import validate_environment
@ -221,7 +222,7 @@ def main():
) )
except KeyboardInterrupt: except KeyboardInterrupt:
console.print("\n[red]Operation cancelled by user[/red]") print_interrupt("Operation cancelled by user")
sys.exit(1) sys.exit(1)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -4,8 +4,13 @@ import time
import uuid import uuid
from typing import Optional, Any, List from typing import Optional, Any, List
import signal
import threading
import time
from typing import Optional
from langgraph.prebuilt import create_react_agent from langgraph.prebuilt import create_react_agent
from ra_aid.console.formatting import print_stage_header from ra_aid.console.formatting import print_stage_header, print_error, print_interrupt
from ra_aid.console.output import print_agent_output from ra_aid.console.output import print_agent_output
from ra_aid.tool_configs import ( from ra_aid.tool_configs import (
get_implementation_tools, get_implementation_tools,
@ -135,10 +140,6 @@ def run_research_agent(
# Run agent with retry logic # Run agent with retry logic
return run_agent_with_retry(agent, prompt, run_config) return run_agent_with_retry(agent, prompt, run_config)
def print_error(msg: str) -> None:
"""Print error messages."""
console.print(f"\n{msg}", style="red")
def run_planning_agent( def run_planning_agent(
base_task: str, base_task: str,
model, model,
@ -272,37 +273,55 @@ def run_task_implementation_agent(
# Run agent with retry logic # Run agent with retry logic
return run_agent_with_retry(agent, prompt, run_config) return run_agent_with_retry(agent, prompt, run_config)
# Stack of the InterruptibleSection instances that are currently entered;
# the innermost active section is the last element.
_CONTEXT_STACK = []
# The section that was innermost when the SIGINT handler last fired, or None
# if no interrupt has been requested. check_interrupt() compares this against
# the current top of the stack.
_INTERRUPT_CONTEXT = None
def _request_interrupt(signum, frame):
    """Signal handler that flags the innermost active section as interrupted.

    Rather than raising from inside the handler, this records the top of
    ``_CONTEXT_STACK`` in ``_INTERRUPT_CONTEXT`` so the interruption is acted
    on at the next ``check_interrupt()`` call. It is a no-op when no section
    is currently entered.

    Args:
        signum: Signal number supplied by the ``signal`` module (unused).
        frame: Interrupted stack frame supplied by the ``signal`` module
            (unused).
    """
    global _INTERRUPT_CONTEXT
    if not _CONTEXT_STACK:
        return
    _INTERRUPT_CONTEXT = _CONTEXT_STACK[-1]
class InterruptibleSection:
    """Context manager marking a region in which interrupts may be requested.

    Entering pushes the instance onto ``_CONTEXT_STACK``; leaving removes it
    again, whether or not the body raised.
    """

    def __enter__(self):
        """Register this section as the innermost active one and return it."""
        _CONTEXT_STACK.append(self)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Unregister this section; exceptions from the body are not suppressed."""
        _CONTEXT_STACK.remove(self)
def check_interrupt():
    """Raise ``KeyboardInterrupt`` if the innermost section was interrupted.

    Does nothing when no section is active, or when the recorded interrupt
    belongs to a section other than the one currently on top of the stack.

    Raises:
        KeyboardInterrupt: If ``_INTERRUPT_CONTEXT`` is the section at the
            top of ``_CONTEXT_STACK``.
    """
    if not _CONTEXT_STACK:
        return
    innermost = _CONTEXT_STACK[-1]
    if _INTERRUPT_CONTEXT is innermost:
        raise KeyboardInterrupt("Interrupt requested")
def run_agent_with_retry(agent, prompt: str, config: dict) -> Optional[str]:
    """Run an agent with retry logic for API errors and graceful interruption.

    While the agent runs on the main thread, SIGINT is routed to
    ``_request_interrupt`` so that Ctrl-C is only acted on at well-defined
    points: before each attempt, after each streamed chunk, and every 0.1s
    while waiting to retry. The previous SIGINT handler is restored on exit.

    Args:
        agent: The agent to run; its ``stream`` method is iterated for output.
        prompt: The prompt to send to the agent.
        config: Configuration dictionary passed to ``agent.stream``.

    Returns:
        Optional[str]: ``"Agent run completed successfully"`` once the stream
        has been fully consumed.

    Raises:
        KeyboardInterrupt: If an interrupt was requested for this run.
        RuntimeError: If the final retry attempt also fails with an API error.
    """
    # signal.signal() may only be called from the main thread.
    on_main_thread = threading.current_thread() is threading.main_thread()
    original_handler = None
    if on_main_thread:
        original_handler = signal.getsignal(signal.SIGINT)
        signal.signal(signal.SIGINT, _request_interrupt)

    max_retries = 20
    base_delay = 1  # Initial delay in seconds

    with InterruptibleSection():
        try:
            for attempt in range(max_retries):
                check_interrupt()
                try:
                    for chunk in agent.stream({"messages": [HumanMessage(content=prompt)]}, config):
                        check_interrupt()
                        print_agent_output(chunk)
                    return "Agent run completed successfully"
                except KeyboardInterrupt:
                    # Never retry an interrupt; hand it straight to the caller.
                    raise
                except (InternalServerError, APITimeoutError, RateLimitError, APIError) as e:
                    if attempt == max_retries - 1:
                        raise RuntimeError(f"Max retries ({max_retries}) exceeded. Last error: {e}") from e
                    delay = base_delay * (2 ** attempt)  # Exponential backoff
                    print_error(f"Encountered {e.__class__.__name__}: {e}. Retrying in {delay}s... (Attempt {attempt+1}/{max_retries})")
                    # Sleep in short slices so an interrupt is noticed promptly.
                    start = time.monotonic()
                    while time.monotonic() - start < delay:
                        check_interrupt()
                        time.sleep(0.1)
        finally:
            # Compare against None rather than truthiness: signal.SIG_DFL has
            # the integer value 0 and would otherwise never be restored,
            # leaving the no-op _request_interrupt handler installed.
            # getsignal() returns None for handlers not installed from Python,
            # which cannot be passed back to signal.signal().
            if on_main_thread and original_handler is not None:
                signal.signal(signal.SIGINT, original_handler)

View File

@ -55,4 +55,5 @@ def print_interrupt(message: str) -> None:
Args: Args:
message: The interruption message to display (supports Markdown formatting) message: The interruption message to display (supports Markdown formatting)
""" """
print() # Give space for "^C"
console.print(Panel(Markdown(message), title="Interrupted", border_style="yellow bold")) console.print(Panel(Markdown(message), title="Interrupted", border_style="yellow bold"))

View File

@ -361,8 +361,9 @@ Exit Criteria:
Remember: Remember:
- Always begin by calling ask_human. - Always begin by calling ask_human.
- Always ask_human before finalizing or exiting. - Always ask_human before finalizing or exiting.
- Never announce that you are going to ask the human, just do it. - Never announce that you are going to use a tool, just quietly use it.
- Do communicate results/responses from tools that you call as it pertains to the users request. - Do communicate results/responses from tools that you call as it pertains to the users request.
- If the user interrupts/cancels an operation, you may want to ask why.
- For deep debugging, logic analysis, or correctness checks, rely on the expert (if expert is available) for guidance. - For deep debugging, logic analysis, or correctness checks, rely on the expert (if expert is available) for guidance.
""" """

View File

@ -12,6 +12,8 @@ from .memory import get_memory_value, get_related_files
from ..llm import initialize_llm from ..llm import initialize_llm
from ..console import print_task_header from ..console import print_task_header
CANCELLED_BY_USER_REASON = "The operation was explicitly cancelled by the user. This typically is an indication that the action requested was not aligned with the user request."
RESEARCH_AGENT_RECURSION_LIMIT = 2 RESEARCH_AGENT_RECURSION_LIMIT = 2
console = Console() console = Console()
@ -64,7 +66,7 @@ def request_research(query: str) -> ResearchResult:
except KeyboardInterrupt: except KeyboardInterrupt:
print_interrupt("Research interrupted by user") print_interrupt("Research interrupted by user")
success = False success = False
reason = "cancelled_by_user" reason = CANCELLED_BY_USER_REASON
except Exception as e: except Exception as e:
print_error(f"Error during research: {str(e)}") print_error(f"Error during research: {str(e)}")
success = False success = False
@ -116,9 +118,9 @@ def request_research_and_implementation(query: str) -> Dict[str, Any]:
success = True success = True
reason = None reason = None
except KeyboardInterrupt: except KeyboardInterrupt:
console.print("\n[yellow]Research interrupted by user[/yellow]") print_interrupt("Research interrupted by user")
success = False success = False
reason = "cancelled_by_user" reason = CANCELLED_BY_USER_REASON
except Exception as e: except Exception as e:
console.print(f"\n[red]Error during research: {str(e)}[/red]") console.print(f"\n[red]Error during research: {str(e)}[/red]")
success = False success = False
@ -177,7 +179,7 @@ def request_task_implementation(task_spec: str) -> Dict[str, Any]:
except KeyboardInterrupt: except KeyboardInterrupt:
print_interrupt("Task implementation interrupted by user") print_interrupt("Task implementation interrupted by user")
success = False success = False
reason = "cancelled_by_user" reason = CANCELLED_BY_USER_REASON
except Exception as e: except Exception as e:
print_error(f"Error during task implementation: {str(e)}") print_error(f"Error during task implementation: {str(e)}")
success = False success = False
@ -226,7 +228,7 @@ def request_implementation(task_spec: str) -> Dict[str, Any]:
except KeyboardInterrupt: except KeyboardInterrupt:
print_interrupt("Planning interrupted by user") print_interrupt("Planning interrupted by user")
success = False success = False
reason = "cancelled_by_user" reason = CANCELLED_BY_USER_REASON
except Exception as e: except Exception as e:
print_error(f"Error during planning: {str(e)}") print_error(f"Error during planning: {str(e)}")
success = False success = False