Allow graceful interruption of agents.
This commit is contained in:
parent
6d51a9e1ef
commit
f80545d693
|
|
@ -8,6 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
|||
## [Unreleased]
|
||||
|
||||
- Chat mode.
|
||||
- Allow agents to be interrupted.
|
||||
|
||||
## [0.7.1] - 2024-12-20
|
||||
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ import sys
|
|||
import uuid
|
||||
from rich.panel import Panel
|
||||
from rich.console import Console
|
||||
from ra_aid.console.formatting import print_interrupt
|
||||
from langgraph.checkpoint.memory import MemorySaver
|
||||
from langgraph.prebuilt import create_react_agent
|
||||
from ra_aid.env import validate_environment
|
||||
|
|
@ -221,7 +222,7 @@ def main():
|
|||
)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
console.print("\n[red]Operation cancelled by user[/red]")
|
||||
print_interrupt("Operation cancelled by user")
|
||||
sys.exit(1)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -4,8 +4,13 @@ import time
|
|||
import uuid
|
||||
from typing import Optional, Any, List
|
||||
|
||||
import signal
|
||||
import threading
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
from langgraph.prebuilt import create_react_agent
|
||||
from ra_aid.console.formatting import print_stage_header
|
||||
from ra_aid.console.formatting import print_stage_header, print_error, print_interrupt
|
||||
from ra_aid.console.output import print_agent_output
|
||||
from ra_aid.tool_configs import (
|
||||
get_implementation_tools,
|
||||
|
|
@ -135,10 +140,6 @@ def run_research_agent(
|
|||
# Run agent with retry logic
|
||||
return run_agent_with_retry(agent, prompt, run_config)
|
||||
|
||||
def print_error(msg: str) -> None:
|
||||
"""Print error messages."""
|
||||
console.print(f"\n{msg}", style="red")
|
||||
|
||||
def run_planning_agent(
|
||||
base_task: str,
|
||||
model,
|
||||
|
|
@ -272,37 +273,55 @@ def run_task_implementation_agent(
|
|||
# Run agent with retry logic
|
||||
return run_agent_with_retry(agent, prompt, run_config)
|
||||
|
||||
_CONTEXT_STACK = []
|
||||
_INTERRUPT_CONTEXT = None
|
||||
|
||||
def _request_interrupt(signum, frame):
|
||||
global _INTERRUPT_CONTEXT
|
||||
if _CONTEXT_STACK:
|
||||
_INTERRUPT_CONTEXT = _CONTEXT_STACK[-1]
|
||||
|
||||
class InterruptibleSection:
|
||||
def __enter__(self):
|
||||
_CONTEXT_STACK.append(self)
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
_CONTEXT_STACK.remove(self)
|
||||
|
||||
def check_interrupt():
|
||||
if _CONTEXT_STACK and _INTERRUPT_CONTEXT is _CONTEXT_STACK[-1]:
|
||||
raise KeyboardInterrupt("Interrupt requested")
|
||||
|
||||
def run_agent_with_retry(agent, prompt: str, config: dict) -> Optional[str]:
|
||||
"""Run an agent with retry logic for internal server errors and task completion handling.
|
||||
|
||||
Args:
|
||||
agent: The agent to run
|
||||
prompt: The prompt to send to the agent
|
||||
config: Configuration dictionary for the agent
|
||||
|
||||
Returns:
|
||||
Optional[str]: The completion message if task was completed, None otherwise
|
||||
|
||||
Handles API errors with exponential backoff retry logic and checks for task
|
||||
completion after each chunk of output.
|
||||
"""
|
||||
original_handler = None
|
||||
if threading.current_thread() is threading.main_thread():
|
||||
original_handler = signal.getsignal(signal.SIGINT)
|
||||
signal.signal(signal.SIGINT, _request_interrupt)
|
||||
|
||||
max_retries = 20
|
||||
base_delay = 1 # Initial delay in seconds
|
||||
|
||||
for attempt in range(max_retries):
|
||||
base_delay = 1
|
||||
|
||||
with InterruptibleSection():
|
||||
try:
|
||||
for chunk in agent.stream(
|
||||
{"messages": [HumanMessage(content=prompt)]},
|
||||
config
|
||||
):
|
||||
print_agent_output(chunk)
|
||||
break
|
||||
except (InternalServerError, APITimeoutError, RateLimitError, APIError) as e:
|
||||
if attempt == max_retries - 1:
|
||||
raise RuntimeError(f"Max retries ({max_retries}) exceeded. Last error: {str(e)}")
|
||||
|
||||
delay = base_delay * (2 ** attempt) # Exponential backoff
|
||||
error_type = e.__class__.__name__
|
||||
print_error(f"Encountered {error_type}: {str(e)}. Retrying in {delay} seconds... (Attempt {attempt + 1}/{max_retries})")
|
||||
time.sleep(delay)
|
||||
continue
|
||||
for attempt in range(max_retries):
|
||||
check_interrupt()
|
||||
try:
|
||||
for chunk in agent.stream({"messages": [HumanMessage(content=prompt)]}, config):
|
||||
check_interrupt()
|
||||
print_agent_output(chunk)
|
||||
return "Agent run completed successfully"
|
||||
except KeyboardInterrupt:
|
||||
raise
|
||||
except (InternalServerError, APITimeoutError, RateLimitError, APIError) as e:
|
||||
if attempt == max_retries - 1:
|
||||
raise RuntimeError(f"Max retries ({max_retries}) exceeded. Last error: {e}")
|
||||
delay = base_delay * (2 ** attempt)
|
||||
print_error(f"Encountered {e.__class__.__name__}: {e}. Retrying in {delay}s... (Attempt {attempt+1}/{max_retries})")
|
||||
start = time.monotonic()
|
||||
while time.monotonic() - start < delay:
|
||||
check_interrupt()
|
||||
time.sleep(0.1)
|
||||
finally:
|
||||
if original_handler and threading.current_thread() is threading.main_thread():
|
||||
signal.signal(signal.SIGINT, original_handler)
|
||||
|
|
|
|||
|
|
@ -55,4 +55,5 @@ def print_interrupt(message: str) -> None:
|
|||
Args:
|
||||
message: The interruption message to display (supports Markdown formatting)
|
||||
"""
|
||||
print() # Give space for "^C"
|
||||
console.print(Panel(Markdown(message), title="Interrupted", border_style="yellow bold"))
|
||||
|
|
|
|||
|
|
@ -361,8 +361,9 @@ Exit Criteria:
|
|||
Remember:
|
||||
- Always begin by calling ask_human.
|
||||
- Always ask_human before finalizing or exiting.
|
||||
- Never announce that you are going to ask the human, just do it.
|
||||
- Never announce that you are going to use a tool, just quietly use it.
|
||||
- Do communicate results/responses from tools that you call as it pertains to the users request.
|
||||
- If the user interrupts/cancels an operation, you may want to ask why.
|
||||
- For deep debugging, logic analysis, or correctness checks, rely on the expert (if expert is available) for guidance.
|
||||
|
||||
"""
|
||||
|
|
@ -12,6 +12,8 @@ from .memory import get_memory_value, get_related_files
|
|||
from ..llm import initialize_llm
|
||||
from ..console import print_task_header
|
||||
|
||||
CANCELLED_BY_USER_REASON = "The operation was explicitly cancelled by the user. This typically is an indication that the action requested was not aligned with the user request."
|
||||
|
||||
RESEARCH_AGENT_RECURSION_LIMIT = 2
|
||||
|
||||
console = Console()
|
||||
|
|
@ -64,7 +66,7 @@ def request_research(query: str) -> ResearchResult:
|
|||
except KeyboardInterrupt:
|
||||
print_interrupt("Research interrupted by user")
|
||||
success = False
|
||||
reason = "cancelled_by_user"
|
||||
reason = CANCELLED_BY_USER_REASON
|
||||
except Exception as e:
|
||||
print_error(f"Error during research: {str(e)}")
|
||||
success = False
|
||||
|
|
@ -116,9 +118,9 @@ def request_research_and_implementation(query: str) -> Dict[str, Any]:
|
|||
success = True
|
||||
reason = None
|
||||
except KeyboardInterrupt:
|
||||
console.print("\n[yellow]Research interrupted by user[/yellow]")
|
||||
print_interrupt("Research interrupted by user")
|
||||
success = False
|
||||
reason = "cancelled_by_user"
|
||||
reason = CANCELLED_BY_USER_REASON
|
||||
except Exception as e:
|
||||
console.print(f"\n[red]Error during research: {str(e)}[/red]")
|
||||
success = False
|
||||
|
|
@ -177,7 +179,7 @@ def request_task_implementation(task_spec: str) -> Dict[str, Any]:
|
|||
except KeyboardInterrupt:
|
||||
print_interrupt("Task implementation interrupted by user")
|
||||
success = False
|
||||
reason = "cancelled_by_user"
|
||||
reason = CANCELLED_BY_USER_REASON
|
||||
except Exception as e:
|
||||
print_error(f"Error during task implementation: {str(e)}")
|
||||
success = False
|
||||
|
|
@ -226,7 +228,7 @@ def request_implementation(task_spec: str) -> Dict[str, Any]:
|
|||
except KeyboardInterrupt:
|
||||
print_interrupt("Planning interrupted by user")
|
||||
success = False
|
||||
reason = "cancelled_by_user"
|
||||
reason = CANCELLED_BY_USER_REASON
|
||||
except Exception as e:
|
||||
print_error(f"Error during planning: {str(e)}")
|
||||
success = False
|
||||
|
|
|
|||
Loading…
Reference in New Issue