should exit propagation

This commit is contained in:
AI Christianson 2025-03-07 11:25:46 -05:00
parent 77856bfa0c
commit 9e9c3ad3d2
6 changed files with 256 additions and 38 deletions

View File

@ -56,16 +56,28 @@ class AgentContext:
self.plan_completed = False
self.completion_message = ""
def mark_should_exit(self) -> None:
def mark_should_exit(self, propagation_depth: Optional[int] = 0) -> None:
"""Mark that the agent should exit execution.
By default only this context is marked; the exit state is propagated to parent contexts only as far as propagation_depth allows.
Args:
propagation_depth: How far up the context hierarchy to propagate the flag.
None: Propagate to all parent contexts
0 (default): Only mark the current context
1: Mark the current context and its immediate parent
2+: Propagate up the specified number of levels
"""
self.agent_should_exit = True
# Propagate to parent context if it exists
# Propagate to parent context based on propagation_depth
if propagation_depth is None:
# None means unlimited propagation (the previous default behavior): pass None along so every ancestor is marked
if self.parent:
self.parent.mark_should_exit()
self.parent.mark_should_exit(propagation_depth)
elif propagation_depth > 0:
# Propagate to parent with decremented depth
if self.parent:
self.parent.mark_should_exit(propagation_depth - 1)
# If propagation_depth is 0, don't propagate to parent
def mark_agent_crashed(self, message: str) -> None:
"""Mark the agent as crashed with the given message.
@ -216,11 +228,19 @@ def should_exit() -> bool:
return context.agent_should_exit if context else False
def mark_should_exit() -> None:
"""Mark that the agent should exit execution."""
def mark_should_exit(propagation_depth: Optional[int] = 0) -> None:
"""Mark that the agent should exit execution.
Args:
propagation_depth: How far up the context hierarchy to propagate the flag.
None: Propagate to all parent contexts
0 (default): Only mark the current context
1: Mark the current context and its immediate parent
2+: Propagate up the specified number of levels
"""
context = get_current_context()
if context:
context.mark_should_exit()
context.mark_should_exit(propagation_depth)
def is_crashed() -> bool:

View File

@ -14,12 +14,6 @@ from ra_aid.prompts.web_research_prompts import WEB_RESEARCH_PROMPT_SECTION_RESE
RESEARCH_COMMON_PROMPT_HEADER = """Current Date: {current_date}
<user query>
{base_task}
</user query>
KEEP IT SIMPLE
<previous research>
<key facts>
{key_facts}
@ -171,6 +165,11 @@ Decision on Implementation
If this is a top-level README.md or docs folder, start there.
If the user explicitly requested implementation, that means you should first perform all the background research for that task, then call request_implementation where the implementation will be carried out.
<user query>
{base_task}
</user query>
KEEP IT SIMPLE
NEVER ANNOUNCE WHAT YOU ARE DOING, JUST DO IT!
@ -189,6 +188,12 @@ You have been spawned by a higher level research agent, so only spawn more resea
When you emit research notes, keep it extremely concise and relevant only to the specific research subquery you've been assigned.
<user query>
{base_task}
</user query>
KEEP IT SIMPLE
NEVER ANNOUNCE WHAT YOU ARE DOING, JUST DO IT!
"""
)

View File

@ -29,17 +29,28 @@ class TestAgentShouldExit:
assert ctx.agent_should_exit is True
def test_propagation_to_parent_context(self):
"""Test that mark_should_exit propagates to parent contexts."""
"""Test that mark_should_exit propagates to parent contexts when specified."""
parent = AgentContext()
child = AgentContext(parent_context=parent)
# Mark child as should exit
child.mark_should_exit()
# Mark child as should exit with propagation to all parents
child.mark_should_exit(propagation_depth=None)
# Verify both child and parent are marked
assert child.agent_should_exit is True
assert parent.agent_should_exit is True
# Reset for the next test
parent.agent_should_exit = False
child.agent_should_exit = False
# Test default behavior (propagation_depth=0)
child.mark_should_exit()
# Verify only child is marked
assert child.agent_should_exit is True
assert parent.agent_should_exit is False
def test_nested_context_manager_propagation(self):
"""Test propagation with nested context managers."""
with agent_context() as outer:
@ -48,9 +59,139 @@ class TestAgentShouldExit:
assert outer.agent_should_exit is False
assert inner.agent_should_exit is False
# Mark inner as should exit
inner.mark_should_exit()
# Mark inner as should exit with explicit propagation to all parents
inner.mark_should_exit(propagation_depth=None)
# Both should now be True
assert inner.agent_should_exit is True
assert outer.agent_should_exit is True
# Test default behavior (propagation_depth=0)
with agent_context() as outer:
with agent_context() as inner:
# Initially both should be False
assert outer.agent_should_exit is False
assert inner.agent_should_exit is False
# Mark inner as should exit with default propagation
inner.mark_should_exit()
# Only inner should be True
assert inner.agent_should_exit is True
assert outer.agent_should_exit is False
def test_mark_should_exit_propagation_depth(self):
"""Test that mark_should_exit respects propagation depth."""
# Create a hierarchy of contexts: ctx1 (root) -> ctx2 -> ctx3
ctx1 = AgentContext()
ctx2 = AgentContext(parent_context=ctx1)
ctx3 = AgentContext(parent_context=ctx2)
# Test case 1: propagation_depth=0 (only marks current context)
ctx3.mark_should_exit(propagation_depth=0)
assert ctx3.agent_should_exit is True
assert ctx2.agent_should_exit is False
assert ctx1.agent_should_exit is False
# Reset all contexts
ctx1.agent_should_exit = False
ctx2.agent_should_exit = False
ctx3.agent_should_exit = False
# Test case 2: propagation_depth=1 (marks current context and immediate parent)
ctx3.mark_should_exit(propagation_depth=1)
assert ctx3.agent_should_exit is True
assert ctx2.agent_should_exit is True
assert ctx1.agent_should_exit is False
# Reset all contexts
ctx1.agent_should_exit = False
ctx2.agent_should_exit = False
ctx3.agent_should_exit = False
# Test case 3: propagation_depth=2 (marks current context, parent, and grandparent)
ctx3.mark_should_exit(propagation_depth=2)
assert ctx3.agent_should_exit is True
assert ctx2.agent_should_exit is True
assert ctx1.agent_should_exit is True
# Reset all contexts
ctx1.agent_should_exit = False
ctx2.agent_should_exit = False
ctx3.agent_should_exit = False
# Test case 4: propagation_depth=None (marks all contexts)
ctx3.mark_should_exit(propagation_depth=None)
assert ctx3.agent_should_exit is True
assert ctx2.agent_should_exit is True
assert ctx1.agent_should_exit is True
# Reset all contexts
ctx1.agent_should_exit = False
ctx2.agent_should_exit = False
ctx3.agent_should_exit = False
# Test case 5: default behavior (propagation_depth=0)
ctx3.mark_should_exit() # Default is now 0
assert ctx3.agent_should_exit is True
assert ctx2.agent_should_exit is False
assert ctx1.agent_should_exit is False
def test_helper_mark_should_exit_propagation_depth(self):
"""Test that helper mark_should_exit function respects propagation depth."""
# Create a hierarchy of contexts
ctx1 = AgentContext()
ctx2 = AgentContext(parent_context=ctx1)
# Test with agent_context to set the current context
with agent_context(ctx2) as current_ctx:
# Test case 1: propagation_depth=0 (only marks current context)
mark_should_exit(propagation_depth=0)
assert current_ctx.agent_should_exit is True
assert ctx2.agent_should_exit is False # The context manager creates a new context
assert ctx1.agent_should_exit is False
# Reset for the next test
ctx1.agent_should_exit = False
ctx2.agent_should_exit = False
with agent_context(ctx2) as current_ctx:
# Test case 2: propagation_depth=1 (marks current context and immediate parent)
mark_should_exit(propagation_depth=1)
assert current_ctx.agent_should_exit is True
# The current_ctx's parent is ctx2, so it should be marked
assert ctx2.agent_should_exit is True
assert ctx1.agent_should_exit is False
# Reset for the next test
ctx1.agent_should_exit = False
ctx2.agent_should_exit = False
with agent_context(ctx2) as current_ctx:
# Test case 3: propagation_depth=2 (marks current context, parent, and grandparent)
mark_should_exit(propagation_depth=2)
assert current_ctx.agent_should_exit is True
assert ctx2.agent_should_exit is True
assert ctx1.agent_should_exit is True
# Reset for the next test
ctx1.agent_should_exit = False
ctx2.agent_should_exit = False
with agent_context(ctx2) as current_ctx:
# Test case 4: propagation_depth=None (marks all contexts)
mark_should_exit(propagation_depth=None)
assert current_ctx.agent_should_exit is True
assert ctx2.agent_should_exit is True
assert ctx1.agent_should_exit is True
# Reset for the next test
ctx1.agent_should_exit = False
ctx2.agent_should_exit = False
with agent_context(ctx2) as current_ctx:
# Test case 5: default behavior (propagation_depth=0)
mark_should_exit() # Default is now 0
assert current_ctx.agent_should_exit is True
assert ctx2.agent_should_exit is False
assert ctx1.agent_should_exit is False

View File

@ -277,7 +277,7 @@ def plan_implementation_completed(message: str) -> str:
Args:
message: Message explaining how the implementation plan was completed
"""
mark_should_exit()
mark_should_exit(propagation_depth=1)
mark_plan_completed(message)
console.print(Panel(Markdown(message), title="✅ Plan Executed"))
log_work_event(f"Completed implementation:\n\n{message}")

View File

@ -120,7 +120,7 @@ class TestExitPropagation:
"""Test cases for the agent_should_exit flag propagation."""
def test_mark_should_exit_propagation(self):
"""Test that mark_should_exit propagates to parent contexts."""
"""Test that mark_should_exit propagates to parent contexts when requested."""
parent = AgentContext()
child = AgentContext(parent_context=parent)
@ -128,15 +128,26 @@ class TestExitPropagation:
assert parent.agent_should_exit is False
assert child.agent_should_exit is False
# Mark the child context as should exit
child.mark_should_exit()
# Test with explicit propagation to all parents
child.mark_should_exit(propagation_depth=None)
# Both child and parent should now have agent_should_exit as True
assert child.agent_should_exit is True
assert parent.agent_should_exit is True
# Reset for next test
parent.agent_should_exit = False
child.agent_should_exit = False
# Test default behavior (no propagation)
child.mark_should_exit()
# Only child should have agent_should_exit as True
assert child.agent_should_exit is True
assert parent.agent_should_exit is False
def test_nested_should_exit_propagation(self):
"""Test that mark_should_exit propagates through multiple levels of parent contexts."""
"""Test that mark_should_exit propagates through multiple levels of parent contexts when requested."""
grandparent = AgentContext()
parent = AgentContext(parent_context=grandparent)
child = AgentContext(parent_context=parent)
@ -146,29 +157,56 @@ class TestExitPropagation:
assert parent.agent_should_exit is False
assert child.agent_should_exit is False
# Mark the child context as should exit
child.mark_should_exit()
# Test with explicit propagation to all parents
child.mark_should_exit(propagation_depth=None)
# All contexts should now have agent_should_exit as True
assert child.agent_should_exit is True
assert parent.agent_should_exit is True
assert grandparent.agent_should_exit is True
# Reset for next test
grandparent.agent_should_exit = False
parent.agent_should_exit = False
child.agent_should_exit = False
# Test default behavior (no propagation)
child.mark_should_exit()
# Only child should have agent_should_exit as True
assert child.agent_should_exit is True
assert parent.agent_should_exit is False
assert grandparent.agent_should_exit is False
def test_context_manager_should_exit_propagation(self):
"""Test that mark_should_exit propagates when using context managers."""
"""Test that mark_should_exit propagates when using context managers when requested."""
with agent_context() as outer:
with agent_context() as inner:
# Initially both contexts should have agent_should_exit as False
assert outer.agent_should_exit is False
assert inner.agent_should_exit is False
# Mark the inner context as should exit
inner.mark_should_exit()
# Test with explicit propagation to all parents
inner.mark_should_exit(propagation_depth=None)
# Both inner and outer should now have agent_should_exit as True
assert inner.agent_should_exit is True
assert outer.agent_should_exit is True
# Test default behavior (no propagation)
with agent_context() as outer:
with agent_context() as inner:
# Initially both contexts should have agent_should_exit as False
assert outer.agent_should_exit is False
assert inner.agent_should_exit is False
# Mark the inner context as should exit with default propagation
inner.mark_should_exit()
# Only inner should have agent_should_exit as True
assert inner.agent_should_exit is True
assert outer.agent_should_exit is False
class TestCrashPropagation:
"""Test cases for the agent_has_crashed flag non-propagation."""
@ -344,10 +382,24 @@ class TestUtilityFunctions:
# Initially both contexts should have agent_should_exit as False
assert should_exit() is False
# Mark the current context (inner) as should exit
mark_should_exit()
# Test with explicit propagation to all parents
mark_should_exit(propagation_depth=None)
# Both inner and outer should now have agent_should_exit as True
assert should_exit() is True
assert inner.agent_should_exit is True
assert outer.agent_should_exit is True
# Test default behavior (no propagation)
with agent_context() as outer:
with agent_context() as inner:
# Initially both contexts should have agent_should_exit as False
assert should_exit() is False
# Mark the current context (inner) as should exit with default propagation
mark_should_exit()
# Only inner should have agent_should_exit as True
assert should_exit() is True
assert inner.agent_should_exit is True
assert outer.agent_should_exit is False

View File

@ -27,8 +27,8 @@ def test_ciayn_agent_stream_respects_should_exit():
# Test stream exits when should_exit is set
with agent_context() as ctx:
# Set should_exit
mark_should_exit()
# Set should_exit with propagation to all parents
mark_should_exit(propagation_depth=None)
# Verify should_exit is set
assert should_exit()
@ -80,8 +80,8 @@ def test_ciayn_agent_bundled_tools_respects_should_exit():
# Set up test messages
messages = {"messages": [HumanMessage(content="test")]}
# Set should_exit before calling stream
mark_should_exit()
# Set should_exit with propagation to all parents
mark_should_exit(propagation_depth=None)
# Call stream
generator = agent.stream(messages)
@ -128,7 +128,7 @@ def test_ciayn_agent_single_tool_respects_should_exit():
def execute_and_exit(*args, **kwargs):
# Set should_exit flag
with agent_context() as ctx:
mark_should_exit()
mark_should_exit(propagation_depth=None)
return "Tool executed"
# Override the mock tool to set should_exit
@ -176,8 +176,8 @@ def test_ciayn_agent_execute_tool_respects_should_exit():
# Test _execute_tool exits when should_exit is set
with agent_context() as ctx:
# Set should_exit
mark_should_exit()
# Set should_exit with propagation to all parents
mark_should_exit(propagation_depth=None)
# Call _execute_tool
message = HumanMessage(content="test_tool()")