Allow graceful interruption of agents.

This commit is contained in:
user 2024-12-22 14:00:17 -05:00
parent 6d51a9e1ef
commit f80545d693
6 changed files with 68 additions and 43 deletions

View File

@ -8,6 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [Unreleased]
- Chat mode.
- Allow agents to be interrupted.
## [0.7.1] - 2024-12-20

View File

@ -3,6 +3,7 @@ import sys
import uuid
from rich.panel import Panel
from rich.console import Console
from ra_aid.console.formatting import print_interrupt
from langgraph.checkpoint.memory import MemorySaver
from langgraph.prebuilt import create_react_agent
from ra_aid.env import validate_environment
@ -221,7 +222,7 @@ def main():
)
except KeyboardInterrupt:
console.print("\n[red]Operation cancelled by user[/red]")
print_interrupt("Operation cancelled by user")
sys.exit(1)
if __name__ == "__main__":

View File

@ -4,8 +4,13 @@ import time
import uuid
from typing import Optional, Any, List
import signal
import threading
import time
from typing import Optional
from langgraph.prebuilt import create_react_agent
from ra_aid.console.formatting import print_stage_header
from ra_aid.console.formatting import print_stage_header, print_error, print_interrupt
from ra_aid.console.output import print_agent_output
from ra_aid.tool_configs import (
get_implementation_tools,
@ -135,10 +140,6 @@ def run_research_agent(
# Run agent with retry logic
return run_agent_with_retry(agent, prompt, run_config)
def print_error(msg: str) -> None:
"""Print error messages."""
console.print(f"\n{msg}", style="red")
def run_planning_agent(
base_task: str,
model,
@ -272,37 +273,55 @@ def run_task_implementation_agent(
# Run agent with retry logic
return run_agent_with_retry(agent, prompt, run_config)
_CONTEXT_STACK = []
_INTERRUPT_CONTEXT = None
def _request_interrupt(signum, frame):
global _INTERRUPT_CONTEXT
if _CONTEXT_STACK:
_INTERRUPT_CONTEXT = _CONTEXT_STACK[-1]
class InterruptibleSection:
def __enter__(self):
_CONTEXT_STACK.append(self)
return self
def __exit__(self, exc_type, exc_value, traceback):
_CONTEXT_STACK.remove(self)
def check_interrupt():
if _CONTEXT_STACK and _INTERRUPT_CONTEXT is _CONTEXT_STACK[-1]:
raise KeyboardInterrupt("Interrupt requested")
def run_agent_with_retry(agent, prompt: str, config: dict) -> Optional[str]:
"""Run an agent with retry logic for internal server errors and task completion handling.
Args:
agent: The agent to run
prompt: The prompt to send to the agent
config: Configuration dictionary for the agent
Returns:
Optional[str]: The completion message if task was completed, None otherwise
Handles API errors with exponential backoff retry logic and checks for task
completion after each chunk of output.
"""
original_handler = None
if threading.current_thread() is threading.main_thread():
original_handler = signal.getsignal(signal.SIGINT)
signal.signal(signal.SIGINT, _request_interrupt)
max_retries = 20
base_delay = 1 # Initial delay in seconds
for attempt in range(max_retries):
base_delay = 1
with InterruptibleSection():
try:
for chunk in agent.stream(
{"messages": [HumanMessage(content=prompt)]},
config
):
print_agent_output(chunk)
break
except (InternalServerError, APITimeoutError, RateLimitError, APIError) as e:
if attempt == max_retries - 1:
raise RuntimeError(f"Max retries ({max_retries}) exceeded. Last error: {str(e)}")
delay = base_delay * (2 ** attempt) # Exponential backoff
error_type = e.__class__.__name__
print_error(f"Encountered {error_type}: {str(e)}. Retrying in {delay} seconds... (Attempt {attempt + 1}/{max_retries})")
time.sleep(delay)
continue
for attempt in range(max_retries):
check_interrupt()
try:
for chunk in agent.stream({"messages": [HumanMessage(content=prompt)]}, config):
check_interrupt()
print_agent_output(chunk)
return "Agent run completed successfully"
except KeyboardInterrupt:
raise
except (InternalServerError, APITimeoutError, RateLimitError, APIError) as e:
if attempt == max_retries - 1:
raise RuntimeError(f"Max retries ({max_retries}) exceeded. Last error: {e}")
delay = base_delay * (2 ** attempt)
print_error(f"Encountered {e.__class__.__name__}: {e}. Retrying in {delay}s... (Attempt {attempt+1}/{max_retries})")
start = time.monotonic()
while time.monotonic() - start < delay:
check_interrupt()
time.sleep(0.1)
finally:
if original_handler and threading.current_thread() is threading.main_thread():
signal.signal(signal.SIGINT, original_handler)

View File

@ -55,4 +55,5 @@ def print_interrupt(message: str) -> None:
Args:
message: The interruption message to display (supports Markdown formatting)
"""
print() # Give space for "^C"
console.print(Panel(Markdown(message), title="Interrupted", border_style="yellow bold"))

View File

@ -361,8 +361,9 @@ Exit Criteria:
Remember:
- Always begin by calling ask_human.
- Always ask_human before finalizing or exiting.
- Never announce that you are going to ask the human, just do it.
- Never announce that you are going to use a tool, just quietly use it.
- Do communicate results/responses from tools that you call as it pertains to the users request.
- If the user interrupts/cancels an operation, you may want to ask why.
- For deep debugging, logic analysis, or correctness checks, rely on the expert (if expert is available) for guidance.
"""

View File

@ -12,6 +12,8 @@ from .memory import get_memory_value, get_related_files
from ..llm import initialize_llm
from ..console import print_task_header
CANCELLED_BY_USER_REASON = "The operation was explicitly cancelled by the user. This typically is an indication that the action requested was not aligned with the user request."
RESEARCH_AGENT_RECURSION_LIMIT = 2
console = Console()
@ -64,7 +66,7 @@ def request_research(query: str) -> ResearchResult:
except KeyboardInterrupt:
print_interrupt("Research interrupted by user")
success = False
reason = "cancelled_by_user"
reason = CANCELLED_BY_USER_REASON
except Exception as e:
print_error(f"Error during research: {str(e)}")
success = False
@ -116,9 +118,9 @@ def request_research_and_implementation(query: str) -> Dict[str, Any]:
success = True
reason = None
except KeyboardInterrupt:
console.print("\n[yellow]Research interrupted by user[/yellow]")
print_interrupt("Research interrupted by user")
success = False
reason = "cancelled_by_user"
reason = CANCELLED_BY_USER_REASON
except Exception as e:
console.print(f"\n[red]Error during research: {str(e)}[/red]")
success = False
@ -177,7 +179,7 @@ def request_task_implementation(task_spec: str) -> Dict[str, Any]:
except KeyboardInterrupt:
print_interrupt("Task implementation interrupted by user")
success = False
reason = "cancelled_by_user"
reason = CANCELLED_BY_USER_REASON
except Exception as e:
print_error(f"Error during task implementation: {str(e)}")
success = False
@ -226,7 +228,7 @@ def request_implementation(task_spec: str) -> Dict[str, Any]:
except KeyboardInterrupt:
print_interrupt("Planning interrupted by user")
success = False
reason = "cancelled_by_user"
reason = CANCELLED_BY_USER_REASON
except Exception as e:
print_error(f"Error during planning: {str(e)}")
success = False