From 9202cf0d6d20f78c9fa741de52235643f220e25b Mon Sep 17 00:00:00 2001 From: AI Christianson Date: Thu, 27 Feb 2025 10:46:57 -0500 Subject: [PATCH] improve agent_should_exit logic --- ra_aid/agent_context.py | 12 +++- ra_aid/tests/ra_aid/test_agent_should_exit.py | 58 +++++++++++++++ tests/ra_aid/test_agent_context.py | 71 +++++++++++++++++++ 3 files changed, 140 insertions(+), 1 deletion(-) create mode 100644 ra_aid/tests/ra_aid/test_agent_should_exit.py diff --git a/ra_aid/agent_context.py b/ra_aid/agent_context.py index f52adf0..c8cd01b 100644 --- a/ra_aid/agent_context.py +++ b/ra_aid/agent_context.py @@ -17,6 +17,9 @@ class AgentContext: Args: parent_context: Optional parent context to inherit state from """ + # Store reference to parent context + self.parent = parent_context + # Initialize completion flags self.task_completed = False self.plan_completed = False @@ -52,8 +55,15 @@ class AgentContext: self.completion_message = "" def mark_should_exit(self) -> None: - """Mark that the agent should exit execution.""" + """Mark that the agent should exit execution. + + This propagates the exit state to all parent contexts. + """ self.agent_should_exit = True + + # Propagate to parent context if it exists + if self.parent: + self.parent.mark_should_exit() @property def is_completed(self) -> bool: diff --git a/ra_aid/tests/ra_aid/test_agent_should_exit.py b/ra_aid/tests/ra_aid/test_agent_should_exit.py new file mode 100644 index 0000000..41c8448 --- /dev/null +++ b/ra_aid/tests/ra_aid/test_agent_should_exit.py @@ -0,0 +1,58 @@ +"""Unit tests for agent_should_exit functionality.""" + +import pytest + +from ra_aid.agent_context import ( + AgentContext, + agent_context, + get_current_context, + mark_should_exit, + should_exit, +) + + +class TestAgentShouldExit: + """Test cases for the agent_should_exit flag and related functions.""" + + def test_mark_should_exit_basic(self): + """Test basic mark_should_exit functionality.""" + context = AgentContext() + assert context.agent_should_exit is False + + context.mark_should_exit() + assert context.agent_should_exit is True + + def test_should_exit_utility(self): + """Test the should_exit utility function.""" + with agent_context() as ctx: + assert should_exit() is False + mark_should_exit() + assert should_exit() is True + assert ctx.agent_should_exit is True + + def test_propagation_to_parent_context(self): + """Test that mark_should_exit propagates to parent contexts.""" + parent = AgentContext() + child = AgentContext(parent_context=parent) + + # Mark child as should exit + child.mark_should_exit() + + # Verify both child and parent are marked + assert child.agent_should_exit is True + assert parent.agent_should_exit is True + + def test_nested_context_manager_propagation(self): + """Test propagation with nested context managers.""" + with agent_context() as outer: + with agent_context() as inner: + # Initially both should be False + assert outer.agent_should_exit is False + assert inner.agent_should_exit is False + + # Mark inner as should exit + inner.mark_should_exit() + + # Both should now be True + assert inner.agent_should_exit is True + assert outer.agent_should_exit is True \ No newline at end of file diff --git a/tests/ra_aid/test_agent_context.py b/tests/ra_aid/test_agent_context.py index 52198b0..07bef8e 100644 --- a/tests/ra_aid/test_agent_context.py +++ b/tests/ra_aid/test_agent_context.py @@ -13,6 +13,8 @@ from ra_aid.agent_context import ( reset_completion_flags, is_completed, get_completion_message, + mark_should_exit, + should_exit, ) @@ -113,6 +115,60 @@ class TestContextManager: assert outer.completion_message == "Outer task" +class TestExitPropagation: + """Test cases for the agent_should_exit flag propagation.""" + + def test_mark_should_exit_propagation(self): + """Test that mark_should_exit propagates to parent contexts.""" + parent = AgentContext() + child = AgentContext(parent_context=parent) + + # Initially both contexts should have agent_should_exit as False + assert parent.agent_should_exit is False + assert child.agent_should_exit is False + + # Mark the child context as should exit + child.mark_should_exit() + + # Both child and parent should now have agent_should_exit as True + assert child.agent_should_exit is True + assert parent.agent_should_exit is True + + def test_nested_should_exit_propagation(self): + """Test that mark_should_exit propagates through multiple levels of parent contexts.""" + grandparent = AgentContext() + parent = AgentContext(parent_context=grandparent) + child = AgentContext(parent_context=parent) + + # Initially all contexts should have agent_should_exit as False + assert grandparent.agent_should_exit is False + assert parent.agent_should_exit is False + assert child.agent_should_exit is False + + # Mark the child context as should exit + child.mark_should_exit() + + # All contexts should now have agent_should_exit as True + assert child.agent_should_exit is True + assert parent.agent_should_exit is True + assert grandparent.agent_should_exit is True + + def test_context_manager_should_exit_propagation(self): + """Test that mark_should_exit propagates when using context managers.""" + with agent_context() as outer: + with agent_context() as inner: + # Initially both contexts should have agent_should_exit as False + assert outer.agent_should_exit is False + assert inner.agent_should_exit is False + + # Mark the inner context as should exit + inner.mark_should_exit() + + # Both inner and outer should now have agent_should_exit as True + assert inner.agent_should_exit is True + assert outer.agent_should_exit is True + + class TestThreadIsolation: """Test thread isolation of context variables.""" @@ -176,3 +232,18 @@ class TestUtilityFunctions: # These should have safe default returns assert is_completed() is False assert get_completion_message() == "" + + def test_mark_should_exit_utility(self): + """Test the mark_should_exit utility function.""" + with agent_context() as outer: + with agent_context() as inner: + # Initially both contexts should have agent_should_exit as False + assert should_exit() is False + + # Mark the current context (inner) as should exit + mark_should_exit() + + # Both inner and outer should now have agent_should_exit as True + assert should_exit() is True + assert inner.agent_should_exit is True + assert outer.agent_should_exit is True \ No newline at end of file