Allow graceful interruption of agents.
This commit is contained in:
parent
6d51a9e1ef
commit
f80545d693
|
|
@ -8,6 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||||
## [Unreleased]
|
## [Unreleased]
|
||||||
|
|
||||||
- Chat mode.
|
- Chat mode.
|
||||||
|
- Allow agents to be interrupted.
|
||||||
|
|
||||||
## [0.7.1] - 2024-12-20
|
## [0.7.1] - 2024-12-20
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,7 @@ import sys
|
||||||
import uuid
|
import uuid
|
||||||
from rich.panel import Panel
|
from rich.panel import Panel
|
||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
|
from ra_aid.console.formatting import print_interrupt
|
||||||
from langgraph.checkpoint.memory import MemorySaver
|
from langgraph.checkpoint.memory import MemorySaver
|
||||||
from langgraph.prebuilt import create_react_agent
|
from langgraph.prebuilt import create_react_agent
|
||||||
from ra_aid.env import validate_environment
|
from ra_aid.env import validate_environment
|
||||||
|
|
@ -221,7 +222,7 @@ def main():
|
||||||
)
|
)
|
||||||
|
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
console.print("\n[red]Operation cancelled by user[/red]")
|
print_interrupt("Operation cancelled by user")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
||||||
|
|
@ -4,8 +4,13 @@ import time
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Optional, Any, List
|
from typing import Optional, Any, List
|
||||||
|
|
||||||
|
import signal
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from langgraph.prebuilt import create_react_agent
|
from langgraph.prebuilt import create_react_agent
|
||||||
from ra_aid.console.formatting import print_stage_header
|
from ra_aid.console.formatting import print_stage_header, print_error, print_interrupt
|
||||||
from ra_aid.console.output import print_agent_output
|
from ra_aid.console.output import print_agent_output
|
||||||
from ra_aid.tool_configs import (
|
from ra_aid.tool_configs import (
|
||||||
get_implementation_tools,
|
get_implementation_tools,
|
||||||
|
|
@ -135,10 +140,6 @@ def run_research_agent(
|
||||||
# Run agent with retry logic
|
# Run agent with retry logic
|
||||||
return run_agent_with_retry(agent, prompt, run_config)
|
return run_agent_with_retry(agent, prompt, run_config)
|
||||||
|
|
||||||
def print_error(msg: str) -> None:
|
|
||||||
"""Print error messages."""
|
|
||||||
console.print(f"\n{msg}", style="red")
|
|
||||||
|
|
||||||
def run_planning_agent(
|
def run_planning_agent(
|
||||||
base_task: str,
|
base_task: str,
|
||||||
model,
|
model,
|
||||||
|
|
@ -272,37 +273,55 @@ def run_task_implementation_agent(
|
||||||
# Run agent with retry logic
|
# Run agent with retry logic
|
||||||
return run_agent_with_retry(agent, prompt, run_config)
|
return run_agent_with_retry(agent, prompt, run_config)
|
||||||
|
|
||||||
|
_CONTEXT_STACK = []
|
||||||
|
_INTERRUPT_CONTEXT = None
|
||||||
|
|
||||||
|
def _request_interrupt(signum, frame):
|
||||||
|
global _INTERRUPT_CONTEXT
|
||||||
|
if _CONTEXT_STACK:
|
||||||
|
_INTERRUPT_CONTEXT = _CONTEXT_STACK[-1]
|
||||||
|
|
||||||
|
class InterruptibleSection:
|
||||||
|
def __enter__(self):
|
||||||
|
_CONTEXT_STACK.append(self)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc_value, traceback):
|
||||||
|
_CONTEXT_STACK.remove(self)
|
||||||
|
|
||||||
|
def check_interrupt():
|
||||||
|
if _CONTEXT_STACK and _INTERRUPT_CONTEXT is _CONTEXT_STACK[-1]:
|
||||||
|
raise KeyboardInterrupt("Interrupt requested")
|
||||||
|
|
||||||
def run_agent_with_retry(agent, prompt: str, config: dict) -> Optional[str]:
|
def run_agent_with_retry(agent, prompt: str, config: dict) -> Optional[str]:
|
||||||
"""Run an agent with retry logic for internal server errors and task completion handling.
|
original_handler = None
|
||||||
|
if threading.current_thread() is threading.main_thread():
|
||||||
Args:
|
original_handler = signal.getsignal(signal.SIGINT)
|
||||||
agent: The agent to run
|
signal.signal(signal.SIGINT, _request_interrupt)
|
||||||
prompt: The prompt to send to the agent
|
|
||||||
config: Configuration dictionary for the agent
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Optional[str]: The completion message if task was completed, None otherwise
|
|
||||||
|
|
||||||
Handles API errors with exponential backoff retry logic and checks for task
|
|
||||||
completion after each chunk of output.
|
|
||||||
"""
|
|
||||||
max_retries = 20
|
max_retries = 20
|
||||||
base_delay = 1 # Initial delay in seconds
|
base_delay = 1
|
||||||
|
|
||||||
for attempt in range(max_retries):
|
with InterruptibleSection():
|
||||||
try:
|
try:
|
||||||
for chunk in agent.stream(
|
for attempt in range(max_retries):
|
||||||
{"messages": [HumanMessage(content=prompt)]},
|
check_interrupt()
|
||||||
config
|
try:
|
||||||
):
|
for chunk in agent.stream({"messages": [HumanMessage(content=prompt)]}, config):
|
||||||
print_agent_output(chunk)
|
check_interrupt()
|
||||||
break
|
print_agent_output(chunk)
|
||||||
except (InternalServerError, APITimeoutError, RateLimitError, APIError) as e:
|
return "Agent run completed successfully"
|
||||||
if attempt == max_retries - 1:
|
except KeyboardInterrupt:
|
||||||
raise RuntimeError(f"Max retries ({max_retries}) exceeded. Last error: {str(e)}")
|
raise
|
||||||
|
except (InternalServerError, APITimeoutError, RateLimitError, APIError) as e:
|
||||||
delay = base_delay * (2 ** attempt) # Exponential backoff
|
if attempt == max_retries - 1:
|
||||||
error_type = e.__class__.__name__
|
raise RuntimeError(f"Max retries ({max_retries}) exceeded. Last error: {e}")
|
||||||
print_error(f"Encountered {error_type}: {str(e)}. Retrying in {delay} seconds... (Attempt {attempt + 1}/{max_retries})")
|
delay = base_delay * (2 ** attempt)
|
||||||
time.sleep(delay)
|
print_error(f"Encountered {e.__class__.__name__}: {e}. Retrying in {delay}s... (Attempt {attempt+1}/{max_retries})")
|
||||||
continue
|
start = time.monotonic()
|
||||||
|
while time.monotonic() - start < delay:
|
||||||
|
check_interrupt()
|
||||||
|
time.sleep(0.1)
|
||||||
|
finally:
|
||||||
|
if original_handler and threading.current_thread() is threading.main_thread():
|
||||||
|
signal.signal(signal.SIGINT, original_handler)
|
||||||
|
|
|
||||||
|
|
@ -55,4 +55,5 @@ def print_interrupt(message: str) -> None:
|
||||||
Args:
|
Args:
|
||||||
message: The interruption message to display (supports Markdown formatting)
|
message: The interruption message to display (supports Markdown formatting)
|
||||||
"""
|
"""
|
||||||
|
print() # Give space for "^C"
|
||||||
console.print(Panel(Markdown(message), title="Interrupted", border_style="yellow bold"))
|
console.print(Panel(Markdown(message), title="Interrupted", border_style="yellow bold"))
|
||||||
|
|
|
||||||
|
|
@ -361,8 +361,9 @@ Exit Criteria:
|
||||||
Remember:
|
Remember:
|
||||||
- Always begin by calling ask_human.
|
- Always begin by calling ask_human.
|
||||||
- Always ask_human before finalizing or exiting.
|
- Always ask_human before finalizing or exiting.
|
||||||
- Never announce that you are going to ask the human, just do it.
|
- Never announce that you are going to use a tool, just quietly use it.
|
||||||
- Do communicate results/responses from tools that you call as it pertains to the users request.
|
- Do communicate results/responses from tools that you call as it pertains to the users request.
|
||||||
|
- If the user interrupts/cancels an operation, you may want to ask why.
|
||||||
- For deep debugging, logic analysis, or correctness checks, rely on the expert (if expert is available) for guidance.
|
- For deep debugging, logic analysis, or correctness checks, rely on the expert (if expert is available) for guidance.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
@ -12,6 +12,8 @@ from .memory import get_memory_value, get_related_files
|
||||||
from ..llm import initialize_llm
|
from ..llm import initialize_llm
|
||||||
from ..console import print_task_header
|
from ..console import print_task_header
|
||||||
|
|
||||||
|
CANCELLED_BY_USER_REASON = "The operation was explicitly cancelled by the user. This typically is an indication that the action requested was not aligned with the user request."
|
||||||
|
|
||||||
RESEARCH_AGENT_RECURSION_LIMIT = 2
|
RESEARCH_AGENT_RECURSION_LIMIT = 2
|
||||||
|
|
||||||
console = Console()
|
console = Console()
|
||||||
|
|
@ -64,7 +66,7 @@ def request_research(query: str) -> ResearchResult:
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
print_interrupt("Research interrupted by user")
|
print_interrupt("Research interrupted by user")
|
||||||
success = False
|
success = False
|
||||||
reason = "cancelled_by_user"
|
reason = CANCELLED_BY_USER_REASON
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print_error(f"Error during research: {str(e)}")
|
print_error(f"Error during research: {str(e)}")
|
||||||
success = False
|
success = False
|
||||||
|
|
@ -116,9 +118,9 @@ def request_research_and_implementation(query: str) -> Dict[str, Any]:
|
||||||
success = True
|
success = True
|
||||||
reason = None
|
reason = None
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
console.print("\n[yellow]Research interrupted by user[/yellow]")
|
print_interrupt("Research interrupted by user")
|
||||||
success = False
|
success = False
|
||||||
reason = "cancelled_by_user"
|
reason = CANCELLED_BY_USER_REASON
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
console.print(f"\n[red]Error during research: {str(e)}[/red]")
|
console.print(f"\n[red]Error during research: {str(e)}[/red]")
|
||||||
success = False
|
success = False
|
||||||
|
|
@ -177,7 +179,7 @@ def request_task_implementation(task_spec: str) -> Dict[str, Any]:
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
print_interrupt("Task implementation interrupted by user")
|
print_interrupt("Task implementation interrupted by user")
|
||||||
success = False
|
success = False
|
||||||
reason = "cancelled_by_user"
|
reason = CANCELLED_BY_USER_REASON
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print_error(f"Error during task implementation: {str(e)}")
|
print_error(f"Error during task implementation: {str(e)}")
|
||||||
success = False
|
success = False
|
||||||
|
|
@ -226,7 +228,7 @@ def request_implementation(task_spec: str) -> Dict[str, Any]:
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
print_interrupt("Planning interrupted by user")
|
print_interrupt("Planning interrupted by user")
|
||||||
success = False
|
success = False
|
||||||
reason = "cancelled_by_user"
|
reason = CANCELLED_BY_USER_REASON
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print_error(f"Error during planning: {str(e)}")
|
print_error(f"Error during planning: {str(e)}")
|
||||||
success = False
|
success = False
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue