improve agent_should_exit logic
This commit is contained in:
parent
9403b8c57f
commit
9202cf0d6d
|
|
@ -17,6 +17,9 @@ class AgentContext:
|
||||||
Args:
|
Args:
|
||||||
parent_context: Optional parent context to inherit state from
|
parent_context: Optional parent context to inherit state from
|
||||||
"""
|
"""
|
||||||
|
# Store reference to parent context
|
||||||
|
self.parent = parent_context
|
||||||
|
|
||||||
# Initialize completion flags
|
# Initialize completion flags
|
||||||
self.task_completed = False
|
self.task_completed = False
|
||||||
self.plan_completed = False
|
self.plan_completed = False
|
||||||
|
|
@ -52,9 +55,16 @@ class AgentContext:
|
||||||
self.completion_message = ""
|
self.completion_message = ""
|
||||||
|
|
||||||
def mark_should_exit(self) -> None:
|
def mark_should_exit(self) -> None:
|
||||||
"""Mark that the agent should exit execution."""
|
"""Mark that the agent should exit execution.
|
||||||
|
|
||||||
|
This propagates the exit state to all parent contexts.
|
||||||
|
"""
|
||||||
self.agent_should_exit = True
|
self.agent_should_exit = True
|
||||||
|
|
||||||
|
# Propagate to parent context if it exists
|
||||||
|
if self.parent:
|
||||||
|
self.parent.mark_should_exit()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_completed(self) -> bool:
|
def is_completed(self) -> bool:
|
||||||
"""Check if the current context is marked as completed."""
|
"""Check if the current context is marked as completed."""
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,58 @@
|
||||||
|
"""Unit tests for agent_should_exit functionality."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from ra_aid.agent_context import (
|
||||||
|
AgentContext,
|
||||||
|
agent_context,
|
||||||
|
get_current_context,
|
||||||
|
mark_should_exit,
|
||||||
|
should_exit,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestAgentShouldExit:
|
||||||
|
"""Test cases for the agent_should_exit flag and related functions."""
|
||||||
|
|
||||||
|
def test_mark_should_exit_basic(self):
|
||||||
|
"""Test basic mark_should_exit functionality."""
|
||||||
|
context = AgentContext()
|
||||||
|
assert context.agent_should_exit is False
|
||||||
|
|
||||||
|
context.mark_should_exit()
|
||||||
|
assert context.agent_should_exit is True
|
||||||
|
|
||||||
|
def test_should_exit_utility(self):
|
||||||
|
"""Test the should_exit utility function."""
|
||||||
|
with agent_context() as ctx:
|
||||||
|
assert should_exit() is False
|
||||||
|
mark_should_exit()
|
||||||
|
assert should_exit() is True
|
||||||
|
assert ctx.agent_should_exit is True
|
||||||
|
|
||||||
|
def test_propagation_to_parent_context(self):
|
||||||
|
"""Test that mark_should_exit propagates to parent contexts."""
|
||||||
|
parent = AgentContext()
|
||||||
|
child = AgentContext(parent_context=parent)
|
||||||
|
|
||||||
|
# Mark child as should exit
|
||||||
|
child.mark_should_exit()
|
||||||
|
|
||||||
|
# Verify both child and parent are marked
|
||||||
|
assert child.agent_should_exit is True
|
||||||
|
assert parent.agent_should_exit is True
|
||||||
|
|
||||||
|
def test_nested_context_manager_propagation(self):
|
||||||
|
"""Test propagation with nested context managers."""
|
||||||
|
with agent_context() as outer:
|
||||||
|
with agent_context() as inner:
|
||||||
|
# Initially both should be False
|
||||||
|
assert outer.agent_should_exit is False
|
||||||
|
assert inner.agent_should_exit is False
|
||||||
|
|
||||||
|
# Mark inner as should exit
|
||||||
|
inner.mark_should_exit()
|
||||||
|
|
||||||
|
# Both should now be True
|
||||||
|
assert inner.agent_should_exit is True
|
||||||
|
assert outer.agent_should_exit is True
|
||||||
|
|
@ -13,6 +13,8 @@ from ra_aid.agent_context import (
|
||||||
reset_completion_flags,
|
reset_completion_flags,
|
||||||
is_completed,
|
is_completed,
|
||||||
get_completion_message,
|
get_completion_message,
|
||||||
|
mark_should_exit,
|
||||||
|
should_exit,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -113,6 +115,60 @@ class TestContextManager:
|
||||||
assert outer.completion_message == "Outer task"
|
assert outer.completion_message == "Outer task"
|
||||||
|
|
||||||
|
|
||||||
|
class TestExitPropagation:
|
||||||
|
"""Test cases for the agent_should_exit flag propagation."""
|
||||||
|
|
||||||
|
def test_mark_should_exit_propagation(self):
|
||||||
|
"""Test that mark_should_exit propagates to parent contexts."""
|
||||||
|
parent = AgentContext()
|
||||||
|
child = AgentContext(parent_context=parent)
|
||||||
|
|
||||||
|
# Initially both contexts should have agent_should_exit as False
|
||||||
|
assert parent.agent_should_exit is False
|
||||||
|
assert child.agent_should_exit is False
|
||||||
|
|
||||||
|
# Mark the child context as should exit
|
||||||
|
child.mark_should_exit()
|
||||||
|
|
||||||
|
# Both child and parent should now have agent_should_exit as True
|
||||||
|
assert child.agent_should_exit is True
|
||||||
|
assert parent.agent_should_exit is True
|
||||||
|
|
||||||
|
def test_nested_should_exit_propagation(self):
|
||||||
|
"""Test that mark_should_exit propagates through multiple levels of parent contexts."""
|
||||||
|
grandparent = AgentContext()
|
||||||
|
parent = AgentContext(parent_context=grandparent)
|
||||||
|
child = AgentContext(parent_context=parent)
|
||||||
|
|
||||||
|
# Initially all contexts should have agent_should_exit as False
|
||||||
|
assert grandparent.agent_should_exit is False
|
||||||
|
assert parent.agent_should_exit is False
|
||||||
|
assert child.agent_should_exit is False
|
||||||
|
|
||||||
|
# Mark the child context as should exit
|
||||||
|
child.mark_should_exit()
|
||||||
|
|
||||||
|
# All contexts should now have agent_should_exit as True
|
||||||
|
assert child.agent_should_exit is True
|
||||||
|
assert parent.agent_should_exit is True
|
||||||
|
assert grandparent.agent_should_exit is True
|
||||||
|
|
||||||
|
def test_context_manager_should_exit_propagation(self):
|
||||||
|
"""Test that mark_should_exit propagates when using context managers."""
|
||||||
|
with agent_context() as outer:
|
||||||
|
with agent_context() as inner:
|
||||||
|
# Initially both contexts should have agent_should_exit as False
|
||||||
|
assert outer.agent_should_exit is False
|
||||||
|
assert inner.agent_should_exit is False
|
||||||
|
|
||||||
|
# Mark the inner context as should exit
|
||||||
|
inner.mark_should_exit()
|
||||||
|
|
||||||
|
# Both inner and outer should now have agent_should_exit as True
|
||||||
|
assert inner.agent_should_exit is True
|
||||||
|
assert outer.agent_should_exit is True
|
||||||
|
|
||||||
|
|
||||||
class TestThreadIsolation:
|
class TestThreadIsolation:
|
||||||
"""Test thread isolation of context variables."""
|
"""Test thread isolation of context variables."""
|
||||||
|
|
||||||
|
|
@ -176,3 +232,18 @@ class TestUtilityFunctions:
|
||||||
# These should have safe default returns
|
# These should have safe default returns
|
||||||
assert is_completed() is False
|
assert is_completed() is False
|
||||||
assert get_completion_message() == ""
|
assert get_completion_message() == ""
|
||||||
|
|
||||||
|
def test_mark_should_exit_utility(self):
|
||||||
|
"""Test the mark_should_exit utility function."""
|
||||||
|
with agent_context() as outer:
|
||||||
|
with agent_context() as inner:
|
||||||
|
# Initially both contexts should have agent_should_exit as False
|
||||||
|
assert should_exit() is False
|
||||||
|
|
||||||
|
# Mark the current context (inner) as should exit
|
||||||
|
mark_should_exit()
|
||||||
|
|
||||||
|
# Both inner and outer should now have agent_should_exit as True
|
||||||
|
assert should_exit() is True
|
||||||
|
assert inner.agent_should_exit is True
|
||||||
|
assert outer.agent_should_exit is True
|
||||||
Loading…
Reference in New Issue