From 6c85a39bbb650e8d27e60a2f3478f9b7cd9adc17 Mon Sep 17 00:00:00 2001 From: AI Christianson Date: Thu, 27 Feb 2025 10:14:17 -0500 Subject: [PATCH] added agent_should_exit context --- ra_aid/agent_context.py | 23 +++++++++++++++++++++++ ra_aid/agent_utils.py | 3 ++- ra_aid/tools/memory.py | 3 ++- 3 files changed, 27 insertions(+), 2 deletions(-) diff --git a/ra_aid/agent_context.py b/ra_aid/agent_context.py index 2d5ddfa..8148963 100644 --- a/ra_aid/agent_context.py +++ b/ra_aid/agent_context.py @@ -21,12 +21,14 @@ class AgentContext: self.task_completed = False self.plan_completed = False self.completion_message = "" + self.agent_should_exit = False # Inherit state from parent if provided if parent_context: self.task_completed = parent_context.task_completed self.plan_completed = parent_context.plan_completed self.completion_message = parent_context.completion_message + self.agent_should_exit = parent_context.agent_should_exit def mark_task_completed(self, message: str) -> None: """Mark the current task as completed. @@ -53,6 +55,10 @@ class AgentContext: self.plan_completed = False self.completion_message = "" + def mark_should_exit(self) -> None: + """Mark that the agent should exit execution.""" + self.agent_should_exit = True + @property def is_completed(self) -> bool: """Check if the current context is marked as completed.""" @@ -148,3 +154,20 @@ def get_completion_message() -> str: """ context = get_current_context() return context.completion_message if context else "" + + +def should_exit() -> bool: + """Check if the agent should exit execution. + + Returns: + True if the agent should exit, False otherwise + """ + context = get_current_context() + return context.agent_should_exit if context else False + + +def mark_should_exit() -> None: + """Mark that the agent should exit execution.""" + context = get_current_context() + if context: + context.mark_should_exit() \ No newline at end of file diff --git a/ra_aid/agent_utils.py b/ra_aid/agent_utils.py index e39d429..8ce01dd 100644 --- a/ra_aid/agent_utils.py +++ b/ra_aid/agent_utils.py @@ -75,6 +75,7 @@ from ra_aid.agent_context import ( is_completed, reset_completion_flags, get_completion_message, + should_exit, ) from ra_aid.tools.memory import ( _global_memory, @@ -903,7 +904,7 @@ def _run_agent_stream(agent: RAgents, msg_list: list[BaseMessage], config: dict) check_interrupt() agent_type = get_agent_type(agent) print_agent_output(chunk, agent_type) - if is_completed(): + if is_completed() or should_exit(): reset_completion_flags() break diff --git a/ra_aid/tools/memory.py b/ra_aid/tools/memory.py index 21e362a..f2df9ea 100644 --- a/ra_aid/tools/memory.py +++ b/ra_aid/tools/memory.py @@ -12,7 +12,7 @@ from rich.markdown import Markdown from rich.panel import Panel from typing_extensions import TypedDict -from ra_aid.agent_context import mark_task_completed, mark_plan_completed +from ra_aid.agent_context import mark_task_completed, mark_plan_completed, mark_should_exit class WorkLogEntry(TypedDict): @@ -339,6 +339,7 @@ def plan_implementation_completed(message: str) -> str: Args: message: Message explaining how the implementation plan was completed """ + mark_should_exit() mark_plan_completed(message) _global_memory["tasks"].clear() # Clear task list when plan is completed _global_memory["task_id_counter"] = 1