improve agent_should_exit logic

This commit is contained in:
AI Christianson 2025-02-27 10:46:57 -05:00
parent 9403b8c57f
commit 9202cf0d6d
3 changed files with 140 additions and 1 deletions

View File

@ -17,6 +17,9 @@ class AgentContext:
Args: Args:
parent_context: Optional parent context to inherit state from parent_context: Optional parent context to inherit state from
""" """
# Store reference to parent context
self.parent = parent_context
# Initialize completion flags # Initialize completion flags
self.task_completed = False self.task_completed = False
self.plan_completed = False self.plan_completed = False
@ -52,8 +55,15 @@ class AgentContext:
self.completion_message = "" self.completion_message = ""
def mark_should_exit(self) -> None: def mark_should_exit(self) -> None:
"""Mark that the agent should exit execution.""" """Mark that the agent should exit execution.
This propagates the exit state to all parent contexts.
"""
self.agent_should_exit = True self.agent_should_exit = True
# Propagate to parent context if it exists
if self.parent:
self.parent.mark_should_exit()
@property @property
def is_completed(self) -> bool: def is_completed(self) -> bool:

View File

@ -0,0 +1,58 @@
"""Unit tests for agent_should_exit functionality."""
import pytest
from ra_aid.agent_context import (
AgentContext,
agent_context,
get_current_context,
mark_should_exit,
should_exit,
)
class TestAgentShouldExit:
"""Test cases for the agent_should_exit flag and related functions."""
def test_mark_should_exit_basic(self):
"""Test basic mark_should_exit functionality."""
context = AgentContext()
assert context.agent_should_exit is False
context.mark_should_exit()
assert context.agent_should_exit is True
def test_should_exit_utility(self):
"""Test the should_exit utility function."""
with agent_context() as ctx:
assert should_exit() is False
mark_should_exit()
assert should_exit() is True
assert ctx.agent_should_exit is True
def test_propagation_to_parent_context(self):
"""Test that mark_should_exit propagates to parent contexts."""
parent = AgentContext()
child = AgentContext(parent_context=parent)
# Mark child as should exit
child.mark_should_exit()
# Verify both child and parent are marked
assert child.agent_should_exit is True
assert parent.agent_should_exit is True
def test_nested_context_manager_propagation(self):
"""Test propagation with nested context managers."""
with agent_context() as outer:
with agent_context() as inner:
# Initially both should be False
assert outer.agent_should_exit is False
assert inner.agent_should_exit is False
# Mark inner as should exit
inner.mark_should_exit()
# Both should now be True
assert inner.agent_should_exit is True
assert outer.agent_should_exit is True

View File

@ -13,6 +13,8 @@ from ra_aid.agent_context import (
reset_completion_flags, reset_completion_flags,
is_completed, is_completed,
get_completion_message, get_completion_message,
mark_should_exit,
should_exit,
) )
@ -113,6 +115,60 @@ class TestContextManager:
assert outer.completion_message == "Outer task" assert outer.completion_message == "Outer task"
class TestExitPropagation:
"""Test cases for the agent_should_exit flag propagation."""
def test_mark_should_exit_propagation(self):
"""Test that mark_should_exit propagates to parent contexts."""
parent = AgentContext()
child = AgentContext(parent_context=parent)
# Initially both contexts should have agent_should_exit as False
assert parent.agent_should_exit is False
assert child.agent_should_exit is False
# Mark the child context as should exit
child.mark_should_exit()
# Both child and parent should now have agent_should_exit as True
assert child.agent_should_exit is True
assert parent.agent_should_exit is True
def test_nested_should_exit_propagation(self):
"""Test that mark_should_exit propagates through multiple levels of parent contexts."""
grandparent = AgentContext()
parent = AgentContext(parent_context=grandparent)
child = AgentContext(parent_context=parent)
# Initially all contexts should have agent_should_exit as False
assert grandparent.agent_should_exit is False
assert parent.agent_should_exit is False
assert child.agent_should_exit is False
# Mark the child context as should exit
child.mark_should_exit()
# All contexts should now have agent_should_exit as True
assert child.agent_should_exit is True
assert parent.agent_should_exit is True
assert grandparent.agent_should_exit is True
def test_context_manager_should_exit_propagation(self):
"""Test that mark_should_exit propagates when using context managers."""
with agent_context() as outer:
with agent_context() as inner:
# Initially both contexts should have agent_should_exit as False
assert outer.agent_should_exit is False
assert inner.agent_should_exit is False
# Mark the inner context as should exit
inner.mark_should_exit()
# Both inner and outer should now have agent_should_exit as True
assert inner.agent_should_exit is True
assert outer.agent_should_exit is True
class TestThreadIsolation: class TestThreadIsolation:
"""Test thread isolation of context variables.""" """Test thread isolation of context variables."""
@ -176,3 +232,18 @@ class TestUtilityFunctions:
# These should have safe default returns # These should have safe default returns
assert is_completed() is False assert is_completed() is False
assert get_completion_message() == "" assert get_completion_message() == ""
def test_mark_should_exit_utility(self):
"""Test the mark_should_exit utility function."""
with agent_context() as outer:
with agent_context() as inner:
# Initially both contexts should have agent_should_exit as False
assert should_exit() is False
# Mark the current context (inner) as should exit
mark_should_exit()
# Both inner and outer should now have agent_should_exit as True
assert should_exit() is True
assert inner.agent_should_exit is True
assert outer.agent_should_exit is True