improve agent_should_exit logic
This commit is contained in:
parent
9403b8c57f
commit
9202cf0d6d
|
|
@ -17,6 +17,9 @@ class AgentContext:
|
|||
Args:
|
||||
parent_context: Optional parent context to inherit state from
|
||||
"""
|
||||
# Store reference to parent context
|
||||
self.parent = parent_context
|
||||
|
||||
# Initialize completion flags
|
||||
self.task_completed = False
|
||||
self.plan_completed = False
|
||||
|
|
@ -52,8 +55,15 @@ class AgentContext:
|
|||
self.completion_message = ""
|
||||
|
||||
def mark_should_exit(self) -> None:
|
||||
"""Mark that the agent should exit execution."""
|
||||
"""Mark that the agent should exit execution.
|
||||
|
||||
This propagates the exit state to all parent contexts.
|
||||
"""
|
||||
self.agent_should_exit = True
|
||||
|
||||
# Propagate to parent context if it exists
|
||||
if self.parent:
|
||||
self.parent.mark_should_exit()
|
||||
|
||||
@property
|
||||
def is_completed(self) -> bool:
|
||||
|
|
|
|||
|
|
@ -0,0 +1,58 @@
|
|||
"""Unit tests for agent_should_exit functionality."""
|
||||
|
||||
import pytest
|
||||
|
||||
from ra_aid.agent_context import (
|
||||
AgentContext,
|
||||
agent_context,
|
||||
get_current_context,
|
||||
mark_should_exit,
|
||||
should_exit,
|
||||
)
|
||||
|
||||
|
||||
class TestAgentShouldExit:
|
||||
"""Test cases for the agent_should_exit flag and related functions."""
|
||||
|
||||
def test_mark_should_exit_basic(self):
|
||||
"""Test basic mark_should_exit functionality."""
|
||||
context = AgentContext()
|
||||
assert context.agent_should_exit is False
|
||||
|
||||
context.mark_should_exit()
|
||||
assert context.agent_should_exit is True
|
||||
|
||||
def test_should_exit_utility(self):
|
||||
"""Test the should_exit utility function."""
|
||||
with agent_context() as ctx:
|
||||
assert should_exit() is False
|
||||
mark_should_exit()
|
||||
assert should_exit() is True
|
||||
assert ctx.agent_should_exit is True
|
||||
|
||||
def test_propagation_to_parent_context(self):
|
||||
"""Test that mark_should_exit propagates to parent contexts."""
|
||||
parent = AgentContext()
|
||||
child = AgentContext(parent_context=parent)
|
||||
|
||||
# Mark child as should exit
|
||||
child.mark_should_exit()
|
||||
|
||||
# Verify both child and parent are marked
|
||||
assert child.agent_should_exit is True
|
||||
assert parent.agent_should_exit is True
|
||||
|
||||
def test_nested_context_manager_propagation(self):
|
||||
"""Test propagation with nested context managers."""
|
||||
with agent_context() as outer:
|
||||
with agent_context() as inner:
|
||||
# Initially both should be False
|
||||
assert outer.agent_should_exit is False
|
||||
assert inner.agent_should_exit is False
|
||||
|
||||
# Mark inner as should exit
|
||||
inner.mark_should_exit()
|
||||
|
||||
# Both should now be True
|
||||
assert inner.agent_should_exit is True
|
||||
assert outer.agent_should_exit is True
|
||||
|
|
@ -13,6 +13,8 @@ from ra_aid.agent_context import (
|
|||
reset_completion_flags,
|
||||
is_completed,
|
||||
get_completion_message,
|
||||
mark_should_exit,
|
||||
should_exit,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -113,6 +115,60 @@ class TestContextManager:
|
|||
assert outer.completion_message == "Outer task"
|
||||
|
||||
|
||||
class TestExitPropagation:
|
||||
"""Test cases for the agent_should_exit flag propagation."""
|
||||
|
||||
def test_mark_should_exit_propagation(self):
|
||||
"""Test that mark_should_exit propagates to parent contexts."""
|
||||
parent = AgentContext()
|
||||
child = AgentContext(parent_context=parent)
|
||||
|
||||
# Initially both contexts should have agent_should_exit as False
|
||||
assert parent.agent_should_exit is False
|
||||
assert child.agent_should_exit is False
|
||||
|
||||
# Mark the child context as should exit
|
||||
child.mark_should_exit()
|
||||
|
||||
# Both child and parent should now have agent_should_exit as True
|
||||
assert child.agent_should_exit is True
|
||||
assert parent.agent_should_exit is True
|
||||
|
||||
def test_nested_should_exit_propagation(self):
|
||||
"""Test that mark_should_exit propagates through multiple levels of parent contexts."""
|
||||
grandparent = AgentContext()
|
||||
parent = AgentContext(parent_context=grandparent)
|
||||
child = AgentContext(parent_context=parent)
|
||||
|
||||
# Initially all contexts should have agent_should_exit as False
|
||||
assert grandparent.agent_should_exit is False
|
||||
assert parent.agent_should_exit is False
|
||||
assert child.agent_should_exit is False
|
||||
|
||||
# Mark the child context as should exit
|
||||
child.mark_should_exit()
|
||||
|
||||
# All contexts should now have agent_should_exit as True
|
||||
assert child.agent_should_exit is True
|
||||
assert parent.agent_should_exit is True
|
||||
assert grandparent.agent_should_exit is True
|
||||
|
||||
def test_context_manager_should_exit_propagation(self):
|
||||
"""Test that mark_should_exit propagates when using context managers."""
|
||||
with agent_context() as outer:
|
||||
with agent_context() as inner:
|
||||
# Initially both contexts should have agent_should_exit as False
|
||||
assert outer.agent_should_exit is False
|
||||
assert inner.agent_should_exit is False
|
||||
|
||||
# Mark the inner context as should exit
|
||||
inner.mark_should_exit()
|
||||
|
||||
# Both inner and outer should now have agent_should_exit as True
|
||||
assert inner.agent_should_exit is True
|
||||
assert outer.agent_should_exit is True
|
||||
|
||||
|
||||
class TestThreadIsolation:
|
||||
"""Test thread isolation of context variables."""
|
||||
|
||||
|
|
@ -176,3 +232,18 @@ class TestUtilityFunctions:
|
|||
# These should have safe default returns
|
||||
assert is_completed() is False
|
||||
assert get_completion_message() == ""
|
||||
|
||||
def test_mark_should_exit_utility(self):
|
||||
"""Test the mark_should_exit utility function."""
|
||||
with agent_context() as outer:
|
||||
with agent_context() as inner:
|
||||
# Initially both contexts should have agent_should_exit as False
|
||||
assert should_exit() is False
|
||||
|
||||
# Mark the current context (inner) as should exit
|
||||
mark_should_exit()
|
||||
|
||||
# Both inner and outer should now have agent_should_exit as True
|
||||
assert should_exit() is True
|
||||
assert inner.agent_should_exit is True
|
||||
assert outer.agent_should_exit is True
|
||||
Loading…
Reference in New Issue