added agent_should_exit context

This commit is contained in:
AI Christianson 2025-02-27 10:14:17 -05:00
parent 6d2b0a148d
commit 6c85a39bbb
3 changed files with 27 additions and 2 deletions

View File

@ -21,12 +21,14 @@ class AgentContext:
self.task_completed = False self.task_completed = False
self.plan_completed = False self.plan_completed = False
self.completion_message = "" self.completion_message = ""
self.agent_should_exit = False
# Inherit state from parent if provided # Inherit state from parent if provided
if parent_context: if parent_context:
self.task_completed = parent_context.task_completed self.task_completed = parent_context.task_completed
self.plan_completed = parent_context.plan_completed self.plan_completed = parent_context.plan_completed
self.completion_message = parent_context.completion_message self.completion_message = parent_context.completion_message
self.agent_should_exit = parent_context.agent_should_exit
def mark_task_completed(self, message: str) -> None: def mark_task_completed(self, message: str) -> None:
"""Mark the current task as completed. """Mark the current task as completed.
@ -53,6 +55,10 @@ class AgentContext:
self.plan_completed = False self.plan_completed = False
self.completion_message = "" self.completion_message = ""
def mark_should_exit(self) -> None:
"""Mark that the agent should exit execution."""
self.agent_should_exit = True
@property @property
def is_completed(self) -> bool: def is_completed(self) -> bool:
"""Check if the current context is marked as completed.""" """Check if the current context is marked as completed."""
@ -148,3 +154,20 @@ def get_completion_message() -> str:
""" """
context = get_current_context() context = get_current_context()
return context.completion_message if context else "" return context.completion_message if context else ""
def should_exit() -> bool:
"""Check if the agent should exit execution.
Returns:
True if the agent should exit, False otherwise
"""
context = get_current_context()
return context.agent_should_exit if context else False
def mark_should_exit() -> None:
"""Mark that the agent should exit execution."""
context = get_current_context()
if context:
context.mark_should_exit()

View File

@ -75,6 +75,7 @@ from ra_aid.agent_context import (
is_completed, is_completed,
reset_completion_flags, reset_completion_flags,
get_completion_message, get_completion_message,
should_exit,
) )
from ra_aid.tools.memory import ( from ra_aid.tools.memory import (
_global_memory, _global_memory,
@ -903,7 +904,7 @@ def _run_agent_stream(agent: RAgents, msg_list: list[BaseMessage], config: dict)
check_interrupt() check_interrupt()
agent_type = get_agent_type(agent) agent_type = get_agent_type(agent)
print_agent_output(chunk, agent_type) print_agent_output(chunk, agent_type)
if is_completed(): if is_completed() or should_exit():
reset_completion_flags() reset_completion_flags()
break break

View File

@ -12,7 +12,7 @@ from rich.markdown import Markdown
from rich.panel import Panel from rich.panel import Panel
from typing_extensions import TypedDict from typing_extensions import TypedDict
from ra_aid.agent_context import mark_task_completed, mark_plan_completed from ra_aid.agent_context import mark_task_completed, mark_plan_completed, mark_should_exit
class WorkLogEntry(TypedDict): class WorkLogEntry(TypedDict):
@ -339,6 +339,7 @@ def plan_implementation_completed(message: str) -> str:
Args: Args:
message: Message explaining how the implementation plan was completed message: Message explaining how the implementation plan was completed
""" """
mark_should_exit()
mark_plan_completed(message) mark_plan_completed(message)
_global_memory["tasks"].clear() # Clear task list when plan is completed _global_memory["tasks"].clear() # Clear task list when plan is completed
_global_memory["task_id_counter"] = 1 _global_memory["task_id_counter"] = 1