diff --git a/ra_aid/__main__.py b/ra_aid/__main__.py index b7a1dbc..1abc399 100644 --- a/ra_aid/__main__.py +++ b/ra_aid/__main__.py @@ -12,6 +12,7 @@ from ra_aid import print_stage_header, print_error, print_interrupt from ra_aid.tools.agent import CANCELLED_BY_USER_REASON from ra_aid.tools.human import ask_human from ra_aid.agent_utils import ( + AgentInterrupt, run_agent_with_retry, run_research_agent, run_planning_agent @@ -227,7 +228,7 @@ def main(): config=config ) - except KeyboardInterrupt: + except (KeyboardInterrupt, AgentInterrupt): print() print(" 👋 Bye!") print() diff --git a/ra_aid/agent_utils.py b/ra_aid/agent_utils.py index 957691a..be59a5a 100644 --- a/ra_aid/agent_utils.py +++ b/ra_aid/agent_utils.py @@ -10,6 +10,14 @@ import threading import time from typing import Optional +class AgentInterrupt(Exception): + """Exception raised when an agent's execution is interrupted. + + This exception is used for internal agent interruption handling, + separate from KeyboardInterrupt which is reserved for top-level handling. + """ + pass + from langgraph.prebuilt import create_react_agent from ra_aid.console.formatting import print_stage_header, print_error from ra_aid.console.output import print_agent_output @@ -52,6 +60,7 @@ from ra_aid.prompts import ( HUMAN_PROMPT_SECTION_RESEARCH ) + console = Console() def run_research_agent( @@ -287,7 +296,9 @@ def _request_interrupt(signum, frame): _INTERRUPT_CONTEXT = _CONTEXT_STACK[-1] if _FEEDBACK_MODE: + print() print(" 👋 Bye!") + print() sys.exit(0) class InterruptibleSection: @@ -300,7 +311,7 @@ class InterruptibleSection: def check_interrupt(): if _CONTEXT_STACK and _INTERRUPT_CONTEXT is _CONTEXT_STACK[-1]: - raise KeyboardInterrupt("Interrupt requested") + raise AgentInterrupt("Interrupt requested") def run_agent_with_retry(agent, prompt: str, config: dict) -> Optional[str]: original_handler = None @@ -324,7 +335,7 @@ def run_agent_with_retry(agent, prompt: str, config: dict) -> Optional[str]: check_interrupt() print_agent_output(chunk) return "Agent run completed successfully" - except KeyboardInterrupt: + except (KeyboardInterrupt, AgentInterrupt): raise except (InternalServerError, APITimeoutError, RateLimitError, APIError) as e: if attempt == max_retries - 1: diff --git a/ra_aid/tools/agent.py b/ra_aid/tools/agent.py index 85cf604..c9ee3cd 100644 --- a/ra_aid/tools/agent.py +++ b/ra_aid/tools/agent.py @@ -3,6 +3,7 @@ from langchain_core.tools import tool from typing import Dict, Any, Union, List from typing_extensions import TypeAlias +from ..agent_utils import AgentInterrupt ResearchResult = Dict[str, Union[str, bool, Dict[int, Any], List[Any], None]] from rich.console import Console @@ -61,15 +62,13 @@ def request_research(query: str) -> ResearchResult: hil=config.get('hil', False), console_message=query ) + except AgentInterrupt: + print() + response = ask_human.invoke({"question": "Why did you interrupt me?"}) + success = False + reason = response if response.strip() else CANCELLED_BY_USER_REASON except KeyboardInterrupt: - try: - print() - response = ask_human.invoke({"question": "Why did you interrupt me?"}) - success = False - reason = response if response.strip() else CANCELLED_BY_USER_REASON - except Exception: - success = False - reason = CANCELLED_BY_USER_REASON + raise except Exception as e: print_error(f"Error during research: {str(e)}") success = False @@ -123,15 +122,13 @@ def request_research_and_implementation(query: str) -> Dict[str, Any]: success = True reason = None + except AgentInterrupt: + print() + response = ask_human.invoke({"question": "Why did you interrupt me?"}) + success = False + reason = response if response.strip() else CANCELLED_BY_USER_REASON except KeyboardInterrupt: - try: - print() - response = ask_human.invoke({"question": "Why did you interrupt me?"}) - success = False - reason = response if response.strip() else CANCELLED_BY_USER_REASON - except Exception: - success = False - reason = CANCELLED_BY_USER_REASON + raise except Exception as e: console.print(f"\n[red]Error during research: {str(e)}[/red]") success = False @@ -194,15 +191,13 @@ def request_task_implementation(task_spec: str) -> Dict[str, Any]: success = True reason = None + except AgentInterrupt: + print() + response = ask_human.invoke({"question": "Why did you interrupt me?"}) + success = False + reason = response if response.strip() else CANCELLED_BY_USER_REASON except KeyboardInterrupt: - try: - print() - response = ask_human.invoke({"question": "Why did you interrupt me?"}) - success = False - reason = response if response.strip() else CANCELLED_BY_USER_REASON - except Exception: - success = False - reason = CANCELLED_BY_USER_REASON + raise except Exception as e: print_error(f"Error during task implementation: {str(e)}") success = False @@ -255,15 +250,13 @@ def request_implementation(task_spec: str) -> Dict[str, Any]: success = True reason = None + except AgentInterrupt: + print() + response = ask_human.invoke({"question": "Why did you interrupt me?"}) + success = False + reason = response if response.strip() else CANCELLED_BY_USER_REASON except KeyboardInterrupt: - try: - print() - response = ask_human.invoke({"question": "Why did you interrupt me?"}) - success = False - reason = response if response.strip() else CANCELLED_BY_USER_REASON - except Exception: - success = False - reason = CANCELLED_BY_USER_REASON + raise except Exception as e: print_error(f"Error during planning: {str(e)}") success = False