should exit propagation

This commit is contained in:
AI Christianson 2025-03-07 11:25:46 -05:00
parent 77856bfa0c
commit 9e9c3ad3d2
6 changed files with 256 additions and 38 deletions

View File

@ -56,16 +56,28 @@ class AgentContext:
self.plan_completed = False self.plan_completed = False
self.completion_message = "" self.completion_message = ""
def mark_should_exit(self) -> None: def mark_should_exit(self, propagation_depth: Optional[int] = 0) -> None:
"""Mark that the agent should exit execution. """Mark that the agent should exit execution.
This propagates the exit state to all parent contexts. Args:
propagation_depth: How far up the context hierarchy to propagate the flag.
None: Propagate to all parent contexts
0 (default): Only mark the current context
1: Mark the current context and its immediate parent
2+: Propagate up the specified number of levels
""" """
self.agent_should_exit = True self.agent_should_exit = True
# Propagate to parent context if it exists # Propagate to parent context based on propagation_depth
if self.parent: if propagation_depth is None:
self.parent.mark_should_exit() # Maintain current behavior of unlimited propagation
if self.parent:
self.parent.mark_should_exit(propagation_depth)
elif propagation_depth > 0:
# Propagate to parent with decremented depth
if self.parent:
self.parent.mark_should_exit(propagation_depth - 1)
# If propagation_depth is 0, don't propagate to parent
def mark_agent_crashed(self, message: str) -> None: def mark_agent_crashed(self, message: str) -> None:
"""Mark the agent as crashed with the given message. """Mark the agent as crashed with the given message.
@ -216,11 +228,19 @@ def should_exit() -> bool:
return context.agent_should_exit if context else False return context.agent_should_exit if context else False
def mark_should_exit() -> None: def mark_should_exit(propagation_depth: Optional[int] = 0) -> None:
"""Mark that the agent should exit execution.""" """Mark that the agent should exit execution.
Args:
propagation_depth: How far up the context hierarchy to propagate the flag.
None: Propagate to all parent contexts
0 (default): Only mark the current context
1: Mark the current context and its immediate parent
2+: Propagate up the specified number of levels
"""
context = get_current_context() context = get_current_context()
if context: if context:
context.mark_should_exit() context.mark_should_exit(propagation_depth)
def is_crashed() -> bool: def is_crashed() -> bool:

View File

@ -14,12 +14,6 @@ from ra_aid.prompts.web_research_prompts import WEB_RESEARCH_PROMPT_SECTION_RESE
RESEARCH_COMMON_PROMPT_HEADER = """Current Date: {current_date} RESEARCH_COMMON_PROMPT_HEADER = """Current Date: {current_date}
<user query>
{base_task}
</user query>
KEEP IT SIMPLE
<previous research> <previous research>
<key facts> <key facts>
{key_facts} {key_facts}
@ -171,6 +165,11 @@ Decision on Implementation
If this is a top-level README.md or docs folder, start there. If this is a top-level README.md or docs folder, start there.
If the user explicitly requested implementation, that means you should first perform all the background research for that task, then call request_implementation where the implementation will be carried out. If the user explicitly requested implementation, that means you should first perform all the background research for that task, then call request_implementation where the implementation will be carried out.
<user query>
{base_task}
</user query>
KEEP IT SIMPLE
NEVER ANNOUNCE WHAT YOU ARE DOING, JUST DO IT! NEVER ANNOUNCE WHAT YOU ARE DOING, JUST DO IT!
@ -189,6 +188,12 @@ You have been spawned by a higher level research agent, so only spawn more resea
When you emit research notes, keep it extremely concise and relevant only to the specific research subquery you've been assigned. When you emit research notes, keep it extremely concise and relevant only to the specific research subquery you've been assigned.
<user query>
{base_task}
</user query>
KEEP IT SIMPLE
NEVER ANNOUNCE WHAT YOU ARE DOING, JUST DO IT! NEVER ANNOUNCE WHAT YOU ARE DOING, JUST DO IT!
""" """
) )

View File

@ -29,16 +29,27 @@ class TestAgentShouldExit:
assert ctx.agent_should_exit is True assert ctx.agent_should_exit is True
def test_propagation_to_parent_context(self): def test_propagation_to_parent_context(self):
"""Test that mark_should_exit propagates to parent contexts.""" """Test that mark_should_exit propagates to parent contexts when specified."""
parent = AgentContext() parent = AgentContext()
child = AgentContext(parent_context=parent) child = AgentContext(parent_context=parent)
# Mark child as should exit # Mark child as should exit with propagation to all parents
child.mark_should_exit() child.mark_should_exit(propagation_depth=None)
# Verify both child and parent are marked # Verify both child and parent are marked
assert child.agent_should_exit is True assert child.agent_should_exit is True
assert parent.agent_should_exit is True assert parent.agent_should_exit is True
# Reset for the next test
parent.agent_should_exit = False
child.agent_should_exit = False
# Test default behavior (propagation_depth=0)
child.mark_should_exit()
# Verify only child is marked
assert child.agent_should_exit is True
assert parent.agent_should_exit is False
def test_nested_context_manager_propagation(self): def test_nested_context_manager_propagation(self):
"""Test propagation with nested context managers.""" """Test propagation with nested context managers."""
@ -48,9 +59,139 @@ class TestAgentShouldExit:
assert outer.agent_should_exit is False assert outer.agent_should_exit is False
assert inner.agent_should_exit is False assert inner.agent_should_exit is False
# Mark inner as should exit # Mark inner as should exit with explicit propagation to all parents
inner.mark_should_exit() inner.mark_should_exit(propagation_depth=None)
# Both should now be True # Both should now be True
assert inner.agent_should_exit is True assert inner.agent_should_exit is True
assert outer.agent_should_exit is True assert outer.agent_should_exit is True
# Test default behavior (propagation_depth=0)
with agent_context() as outer:
with agent_context() as inner:
# Initially both should be False
assert outer.agent_should_exit is False
assert inner.agent_should_exit is False
# Mark inner as should exit with default propagation
inner.mark_should_exit()
# Only inner should be True
assert inner.agent_should_exit is True
assert outer.agent_should_exit is False
def test_mark_should_exit_propagation_depth(self):
"""Test that mark_should_exit respects propagation depth."""
# Create a hierarchy of contexts: ctx1 (root) -> ctx2 -> ctx3
ctx1 = AgentContext()
ctx2 = AgentContext(parent_context=ctx1)
ctx3 = AgentContext(parent_context=ctx2)
# Test case 1: propagation_depth=0 (only marks current context)
ctx3.mark_should_exit(propagation_depth=0)
assert ctx3.agent_should_exit is True
assert ctx2.agent_should_exit is False
assert ctx1.agent_should_exit is False
# Reset all contexts
ctx1.agent_should_exit = False
ctx2.agent_should_exit = False
ctx3.agent_should_exit = False
# Test case 2: propagation_depth=1 (marks current context and immediate parent)
ctx3.mark_should_exit(propagation_depth=1)
assert ctx3.agent_should_exit is True
assert ctx2.agent_should_exit is True
assert ctx1.agent_should_exit is False
# Reset all contexts
ctx1.agent_should_exit = False
ctx2.agent_should_exit = False
ctx3.agent_should_exit = False
# Test case 3: propagation_depth=2 (marks current context, parent, and grandparent)
ctx3.mark_should_exit(propagation_depth=2)
assert ctx3.agent_should_exit is True
assert ctx2.agent_should_exit is True
assert ctx1.agent_should_exit is True
# Reset all contexts
ctx1.agent_should_exit = False
ctx2.agent_should_exit = False
ctx3.agent_should_exit = False
# Test case 4: propagation_depth=None (marks all contexts)
ctx3.mark_should_exit(propagation_depth=None)
assert ctx3.agent_should_exit is True
assert ctx2.agent_should_exit is True
assert ctx1.agent_should_exit is True
# Reset all contexts
ctx1.agent_should_exit = False
ctx2.agent_should_exit = False
ctx3.agent_should_exit = False
# Test case 5: default behavior (propagation_depth=0)
ctx3.mark_should_exit() # Default is now 0
assert ctx3.agent_should_exit is True
assert ctx2.agent_should_exit is False
assert ctx1.agent_should_exit is False
def test_helper_mark_should_exit_propagation_depth(self):
"""Test that helper mark_should_exit function respects propagation depth."""
# Create a hierarchy of contexts
ctx1 = AgentContext()
ctx2 = AgentContext(parent_context=ctx1)
# Test with agent_context to set the current context
with agent_context(ctx2) as current_ctx:
# Test case 1: propagation_depth=0 (only marks current context)
mark_should_exit(propagation_depth=0)
assert current_ctx.agent_should_exit is True
assert ctx2.agent_should_exit is False # The context manager creates a new context
assert ctx1.agent_should_exit is False
# Reset for the next test
ctx1.agent_should_exit = False
ctx2.agent_should_exit = False
with agent_context(ctx2) as current_ctx:
# Test case 2: propagation_depth=1 (marks current context and immediate parent)
mark_should_exit(propagation_depth=1)
assert current_ctx.agent_should_exit is True
# The current_ctx's parent is ctx2, so it should be marked
assert ctx2.agent_should_exit is True
assert ctx1.agent_should_exit is False
# Reset for the next test
ctx1.agent_should_exit = False
ctx2.agent_should_exit = False
with agent_context(ctx2) as current_ctx:
# Test case 3: propagation_depth=2 (marks current context, parent, and grandparent)
mark_should_exit(propagation_depth=2)
assert current_ctx.agent_should_exit is True
assert ctx2.agent_should_exit is True
assert ctx1.agent_should_exit is True
# Reset for the next test
ctx1.agent_should_exit = False
ctx2.agent_should_exit = False
with agent_context(ctx2) as current_ctx:
# Test case 4: propagation_depth=None (marks all contexts)
mark_should_exit(propagation_depth=None)
assert current_ctx.agent_should_exit is True
assert ctx2.agent_should_exit is True
assert ctx1.agent_should_exit is True
# Reset for the next test
ctx1.agent_should_exit = False
ctx2.agent_should_exit = False
with agent_context(ctx2) as current_ctx:
# Test case 5: default behavior (propagation_depth=0)
mark_should_exit() # Default is now 0
assert current_ctx.agent_should_exit is True
assert ctx2.agent_should_exit is False
assert ctx1.agent_should_exit is False

View File

@ -277,7 +277,7 @@ def plan_implementation_completed(message: str) -> str:
Args: Args:
message: Message explaining how the implementation plan was completed message: Message explaining how the implementation plan was completed
""" """
mark_should_exit() mark_should_exit(propagation_depth=1)
mark_plan_completed(message) mark_plan_completed(message)
console.print(Panel(Markdown(message), title="✅ Plan Executed")) console.print(Panel(Markdown(message), title="✅ Plan Executed"))
log_work_event(f"Completed implementation:\n\n{message}") log_work_event(f"Completed implementation:\n\n{message}")

View File

@ -120,7 +120,7 @@ class TestExitPropagation:
"""Test cases for the agent_should_exit flag propagation.""" """Test cases for the agent_should_exit flag propagation."""
def test_mark_should_exit_propagation(self): def test_mark_should_exit_propagation(self):
"""Test that mark_should_exit propagates to parent contexts.""" """Test that mark_should_exit propagates to parent contexts when requested."""
parent = AgentContext() parent = AgentContext()
child = AgentContext(parent_context=parent) child = AgentContext(parent_context=parent)
@ -128,15 +128,26 @@ class TestExitPropagation:
assert parent.agent_should_exit is False assert parent.agent_should_exit is False
assert child.agent_should_exit is False assert child.agent_should_exit is False
# Mark the child context as should exit # Test with explicit propagation to all parents
child.mark_should_exit() child.mark_should_exit(propagation_depth=None)
# Both child and parent should now have agent_should_exit as True # Both child and parent should now have agent_should_exit as True
assert child.agent_should_exit is True assert child.agent_should_exit is True
assert parent.agent_should_exit is True assert parent.agent_should_exit is True
# Reset for next test
parent.agent_should_exit = False
child.agent_should_exit = False
# Test default behavior (no propagation)
child.mark_should_exit()
# Only child should have agent_should_exit as True
assert child.agent_should_exit is True
assert parent.agent_should_exit is False
def test_nested_should_exit_propagation(self): def test_nested_should_exit_propagation(self):
"""Test that mark_should_exit propagates through multiple levels of parent contexts.""" """Test that mark_should_exit propagates through multiple levels of parent contexts when requested."""
grandparent = AgentContext() grandparent = AgentContext()
parent = AgentContext(parent_context=grandparent) parent = AgentContext(parent_context=grandparent)
child = AgentContext(parent_context=parent) child = AgentContext(parent_context=parent)
@ -146,28 +157,55 @@ class TestExitPropagation:
assert parent.agent_should_exit is False assert parent.agent_should_exit is False
assert child.agent_should_exit is False assert child.agent_should_exit is False
# Mark the child context as should exit # Test with explicit propagation to all parents
child.mark_should_exit() child.mark_should_exit(propagation_depth=None)
# All contexts should now have agent_should_exit as True # All contexts should now have agent_should_exit as True
assert child.agent_should_exit is True assert child.agent_should_exit is True
assert parent.agent_should_exit is True assert parent.agent_should_exit is True
assert grandparent.agent_should_exit is True assert grandparent.agent_should_exit is True
# Reset for next test
grandparent.agent_should_exit = False
parent.agent_should_exit = False
child.agent_should_exit = False
# Test default behavior (no propagation)
child.mark_should_exit()
# Only child should have agent_should_exit as True
assert child.agent_should_exit is True
assert parent.agent_should_exit is False
assert grandparent.agent_should_exit is False
def test_context_manager_should_exit_propagation(self): def test_context_manager_should_exit_propagation(self):
"""Test that mark_should_exit propagates when using context managers.""" """Test that mark_should_exit propagates when using context managers when requested."""
with agent_context() as outer: with agent_context() as outer:
with agent_context() as inner: with agent_context() as inner:
# Initially both contexts should have agent_should_exit as False # Initially both contexts should have agent_should_exit as False
assert outer.agent_should_exit is False assert outer.agent_should_exit is False
assert inner.agent_should_exit is False assert inner.agent_should_exit is False
# Mark the inner context as should exit # Test with explicit propagation to all parents
inner.mark_should_exit() inner.mark_should_exit(propagation_depth=None)
# Both inner and outer should now have agent_should_exit as True # Both inner and outer should now have agent_should_exit as True
assert inner.agent_should_exit is True assert inner.agent_should_exit is True
assert outer.agent_should_exit is True assert outer.agent_should_exit is True
# Test default behavior (no propagation)
with agent_context() as outer:
with agent_context() as inner:
# Initially both contexts should have agent_should_exit as False
assert outer.agent_should_exit is False
assert inner.agent_should_exit is False
# Mark the inner context as should exit with default propagation
inner.mark_should_exit()
# Only inner should have agent_should_exit as True
assert inner.agent_should_exit is True
assert outer.agent_should_exit is False
class TestCrashPropagation: class TestCrashPropagation:
@ -344,10 +382,24 @@ class TestUtilityFunctions:
# Initially both contexts should have agent_should_exit as False # Initially both contexts should have agent_should_exit as False
assert should_exit() is False assert should_exit() is False
# Mark the current context (inner) as should exit # Test with explicit propagation to all parents
mark_should_exit() mark_should_exit(propagation_depth=None)
# Both inner and outer should now have agent_should_exit as True # Both inner and outer should now have agent_should_exit as True
assert should_exit() is True assert should_exit() is True
assert inner.agent_should_exit is True assert inner.agent_should_exit is True
assert outer.agent_should_exit is True assert outer.agent_should_exit is True
# Test default behavior (no propagation)
with agent_context() as outer:
with agent_context() as inner:
# Initially both contexts should have agent_should_exit as False
assert should_exit() is False
# Mark the current context (inner) as should exit with default propagation
mark_should_exit()
# Only inner should have agent_should_exit as True
assert should_exit() is True
assert inner.agent_should_exit is True
assert outer.agent_should_exit is False

View File

@ -27,8 +27,8 @@ def test_ciayn_agent_stream_respects_should_exit():
# Test stream exits when should_exit is set # Test stream exits when should_exit is set
with agent_context() as ctx: with agent_context() as ctx:
# Set should_exit # Set should_exit with propagation to all parents
mark_should_exit() mark_should_exit(propagation_depth=None)
# Verify should_exit is set # Verify should_exit is set
assert should_exit() assert should_exit()
@ -80,8 +80,8 @@ def test_ciayn_agent_bundled_tools_respects_should_exit():
# Set up test messages # Set up test messages
messages = {"messages": [HumanMessage(content="test")]} messages = {"messages": [HumanMessage(content="test")]}
# Set should_exit before calling stream # Set should_exit with propagation to all parents
mark_should_exit() mark_should_exit(propagation_depth=None)
# Call stream # Call stream
generator = agent.stream(messages) generator = agent.stream(messages)
@ -128,7 +128,7 @@ def test_ciayn_agent_single_tool_respects_should_exit():
def execute_and_exit(*args, **kwargs): def execute_and_exit(*args, **kwargs):
# Set should_exit flag # Set should_exit flag
with agent_context() as ctx: with agent_context() as ctx:
mark_should_exit() mark_should_exit(propagation_depth=None)
return "Tool executed" return "Tool executed"
# Override the mock tool to set should_exit # Override the mock tool to set should_exit
@ -176,8 +176,8 @@ def test_ciayn_agent_execute_tool_respects_should_exit():
# Test _execute_tool exits when should_exit is set # Test _execute_tool exits when should_exit is set
with agent_context() as ctx: with agent_context() as ctx:
# Set should_exit # Set should_exit with propagation to all parents
mark_should_exit() mark_should_exit(propagation_depth=None)
# Call _execute_tool # Call _execute_tool
message = HumanMessage(content="test_tool()") message = HumanMessage(content="test_tool()")