diff --git a/ra_aid/agent_context.py b/ra_aid/agent_context.py
index c3b11e7..c6d76ba 100644
--- a/ra_aid/agent_context.py
+++ b/ra_aid/agent_context.py
@@ -56,16 +56,28 @@ class AgentContext:
self.plan_completed = False
self.completion_message = ""
- def mark_should_exit(self) -> None:
+ def mark_should_exit(self, propagation_depth: Optional[int] = 0) -> None:
"""Mark that the agent should exit execution.
- This propagates the exit state to all parent contexts.
+ Args:
+ propagation_depth: How many ancestor levels above this context also receive the flag.
+ None: Propagate to every ancestor context (what this method always did before the parameter existed)
+ 0 (default) or negative: Only mark the current context
+ 1: Mark the current context and its immediate parent
+ N >= 2: Mark the current context and up to N ancestors (stops early at the root)
"""
self.agent_should_exit = True
- # Propagate to parent context if it exists
- if self.parent:
- self.parent.mark_should_exit()
+ # Decide whether the parent is marked too, based on propagation_depth
+ if propagation_depth is None:
+ # Unlimited propagation: pass None along so every ancestor is marked
+ if self.parent:
+ self.parent.mark_should_exit(propagation_depth)
+ elif propagation_depth > 0:
+ # Bounded propagation: each level consumes one unit of depth
+ if self.parent:
+ self.parent.mark_should_exit(propagation_depth - 1)
+ # propagation_depth <= 0: leave the parent untouched
def mark_agent_crashed(self, message: str) -> None:
"""Mark the agent as crashed with the given message.
@@ -216,11 +228,19 @@ def should_exit() -> bool:
return context.agent_should_exit if context else False
-def mark_should_exit() -> None:
- """Mark that the agent should exit execution."""
+def mark_should_exit(propagation_depth: Optional[int] = 0) -> None:
+ """Mark the current agent context as should-exit; does nothing if there is no current context.
+
+ Args:
+ propagation_depth: Forwarded unchanged to the current context's mark_should_exit.
+ None: Propagate to all parent contexts
+ 0 (default): Only mark the current context
+ 1: Mark the current context and its immediate parent
+ 2+: Propagate up the specified number of levels
+ """
context = get_current_context()
if context:
- context.mark_should_exit()
+ context.mark_should_exit(propagation_depth)
def is_crashed() -> bool:
diff --git a/ra_aid/prompts/research_prompts.py b/ra_aid/prompts/research_prompts.py
index 3540505..8780cbe 100644
--- a/ra_aid/prompts/research_prompts.py
+++ b/ra_aid/prompts/research_prompts.py
@@ -14,12 +14,6 @@ from ra_aid.prompts.web_research_prompts import WEB_RESEARCH_PROMPT_SECTION_RESE
RESEARCH_COMMON_PROMPT_HEADER = """Current Date: {current_date}
-
-{base_task}
-
-
-KEEP IT SIMPLE
-
{key_facts}
@@ -171,6 +165,11 @@ Decision on Implementation
If this is a top-level README.md or docs folder, start there.
If the user explicitly requested implementation, that means you should first perform all the background research for that task, then call request_implementation where the implementation will be carried out.
+
+{base_task}
+
+
+KEEP IT SIMPLE
NEVER ANNOUNCE WHAT YOU ARE DOING, JUST DO IT!
@@ -189,6 +188,12 @@ You have been spawned by a higher level research agent, so only spawn more resea
When you emit research notes, keep it extremely concise and relevant only to the specific research subquery you've been assigned.
+
+{base_task}
+
+
+KEEP IT SIMPLE
+
NEVER ANNOUNCE WHAT YOU ARE DOING, JUST DO IT!
"""
)
\ No newline at end of file
diff --git a/ra_aid/tests/ra_aid/test_agent_should_exit.py b/ra_aid/tests/ra_aid/test_agent_should_exit.py
index 055f09b..8b51902 100644
--- a/ra_aid/tests/ra_aid/test_agent_should_exit.py
+++ b/ra_aid/tests/ra_aid/test_agent_should_exit.py
@@ -29,16 +29,27 @@ class TestAgentShouldExit:
assert ctx.agent_should_exit is True
def test_propagation_to_parent_context(self):
- """Test that mark_should_exit propagates to parent contexts."""
+ """Test that mark_should_exit propagates to parent contexts when specified."""
parent = AgentContext()
child = AgentContext(parent_context=parent)
- # Mark child as should exit
- child.mark_should_exit()
+ # Mark child as should exit with propagation to all parents
+ child.mark_should_exit(propagation_depth=None)
# Verify both child and parent are marked
assert child.agent_should_exit is True
assert parent.agent_should_exit is True
+
+ # Reset for the next test
+ parent.agent_should_exit = False
+ child.agent_should_exit = False
+
+ # Test default behavior (propagation_depth=0)
+ child.mark_should_exit()
+
+ # Verify only child is marked
+ assert child.agent_should_exit is True
+ assert parent.agent_should_exit is False
def test_nested_context_manager_propagation(self):
"""Test propagation with nested context managers."""
@@ -48,9 +59,139 @@ class TestAgentShouldExit:
assert outer.agent_should_exit is False
assert inner.agent_should_exit is False
- # Mark inner as should exit
- inner.mark_should_exit()
+ # Mark inner as should exit with explicit propagation to all parents
+ inner.mark_should_exit(propagation_depth=None)
# Both should now be True
assert inner.agent_should_exit is True
assert outer.agent_should_exit is True
+
+ # Test default behavior (propagation_depth=0)
+ with agent_context() as outer:
+ with agent_context() as inner:
+ # Initially both should be False
+ assert outer.agent_should_exit is False
+ assert inner.agent_should_exit is False
+
+ # Mark inner as should exit with default propagation
+ inner.mark_should_exit()
+
+ # Only inner should be True
+ assert inner.agent_should_exit is True
+ assert outer.agent_should_exit is False
+
+ def test_mark_should_exit_propagation_depth(self):
+ """Test that mark_should_exit respects propagation depth."""
+ # Create a hierarchy of contexts: ctx1 (root) -> ctx2 -> ctx3
+ ctx1 = AgentContext()
+ ctx2 = AgentContext(parent_context=ctx1)
+ ctx3 = AgentContext(parent_context=ctx2)
+
+ # Test case 1: propagation_depth=0 (only marks current context)
+ ctx3.mark_should_exit(propagation_depth=0)
+ assert ctx3.agent_should_exit is True
+ assert ctx2.agent_should_exit is False
+ assert ctx1.agent_should_exit is False
+
+ # Reset all contexts
+ ctx1.agent_should_exit = False
+ ctx2.agent_should_exit = False
+ ctx3.agent_should_exit = False
+
+ # Test case 2: propagation_depth=1 (marks current context and immediate parent)
+ ctx3.mark_should_exit(propagation_depth=1)
+ assert ctx3.agent_should_exit is True
+ assert ctx2.agent_should_exit is True
+ assert ctx1.agent_should_exit is False
+
+ # Reset all contexts
+ ctx1.agent_should_exit = False
+ ctx2.agent_should_exit = False
+ ctx3.agent_should_exit = False
+
+ # Test case 3: propagation_depth=2 (marks current context, parent, and grandparent)
+ ctx3.mark_should_exit(propagation_depth=2)
+ assert ctx3.agent_should_exit is True
+ assert ctx2.agent_should_exit is True
+ assert ctx1.agent_should_exit is True
+
+ # Reset all contexts
+ ctx1.agent_should_exit = False
+ ctx2.agent_should_exit = False
+ ctx3.agent_should_exit = False
+
+ # Test case 4: propagation_depth=None (marks all contexts)
+ ctx3.mark_should_exit(propagation_depth=None)
+ assert ctx3.agent_should_exit is True
+ assert ctx2.agent_should_exit is True
+ assert ctx1.agent_should_exit is True
+
+ # Reset all contexts
+ ctx1.agent_should_exit = False
+ ctx2.agent_should_exit = False
+ ctx3.agent_should_exit = False
+
+ # Test case 5: default behavior (propagation_depth=0)
+ ctx3.mark_should_exit() # Default is now 0
+ assert ctx3.agent_should_exit is True
+ assert ctx2.agent_should_exit is False
+ assert ctx1.agent_should_exit is False
+
+ def test_helper_mark_should_exit_propagation_depth(self):
+ """Test that the module-level mark_should_exit helper respects propagation depth."""
+ # Create a hierarchy of contexts: ctx1 (root) -> ctx2
+ ctx1 = AgentContext()
+ ctx2 = AgentContext(parent_context=ctx1)
+
+ # agent_context(ctx2) yields a new current context whose parent is ctx2
+ with agent_context(ctx2) as current_ctx:
+ # Test case 1: propagation_depth=0 (only marks current context)
+ mark_should_exit(propagation_depth=0)
+ assert current_ctx.agent_should_exit is True
+ assert ctx2.agent_should_exit is False # ctx2 is the parent of current_ctx, not the current context itself
+ assert ctx1.agent_should_exit is False
+
+ # Reset the ancestors for the next case (current_ctx is recreated below)
+ ctx1.agent_should_exit = False
+ ctx2.agent_should_exit = False
+
+ with agent_context(ctx2) as current_ctx:
+ # Test case 2: propagation_depth=1 (marks current context and immediate parent)
+ mark_should_exit(propagation_depth=1)
+ assert current_ctx.agent_should_exit is True
+ # The current_ctx's parent is ctx2, so it should be marked
+ assert ctx2.agent_should_exit is True
+ assert ctx1.agent_should_exit is False
+
+ # Reset the ancestors for the next case (current_ctx is recreated below)
+ ctx1.agent_should_exit = False
+ ctx2.agent_should_exit = False
+
+ with agent_context(ctx2) as current_ctx:
+ # Test case 3: propagation_depth=2 (marks current context, parent, and grandparent)
+ mark_should_exit(propagation_depth=2)
+ assert current_ctx.agent_should_exit is True
+ assert ctx2.agent_should_exit is True
+ assert ctx1.agent_should_exit is True
+
+ # Reset the ancestors for the next case (current_ctx is recreated below)
+ ctx1.agent_should_exit = False
+ ctx2.agent_should_exit = False
+
+ with agent_context(ctx2) as current_ctx:
+ # Test case 4: propagation_depth=None (marks all contexts)
+ mark_should_exit(propagation_depth=None)
+ assert current_ctx.agent_should_exit is True
+ assert ctx2.agent_should_exit is True
+ assert ctx1.agent_should_exit is True
+
+ # Reset the ancestors for the next case (current_ctx is recreated below)
+ ctx1.agent_should_exit = False
+ ctx2.agent_should_exit = False
+
+ with agent_context(ctx2) as current_ctx:
+ # Test case 5: default behavior (propagation_depth=0)
+ mark_should_exit() # No argument: the default depth is 0
+ assert current_ctx.agent_should_exit is True
+ assert ctx2.agent_should_exit is False
+ assert ctx1.agent_should_exit is False
diff --git a/ra_aid/tools/memory.py b/ra_aid/tools/memory.py
index e5d86b1..c334b4b 100644
--- a/ra_aid/tools/memory.py
+++ b/ra_aid/tools/memory.py
@@ -277,7 +277,7 @@ def plan_implementation_completed(message: str) -> str:
Args:
message: Message explaining how the implementation plan was completed
"""
- mark_should_exit()
+ mark_should_exit(propagation_depth=1)
mark_plan_completed(message)
console.print(Panel(Markdown(message), title="✅ Plan Executed"))
log_work_event(f"Completed implementation:\n\n{message}")
diff --git a/tests/ra_aid/test_agent_context.py b/tests/ra_aid/test_agent_context.py
index 4d6b1fd..8da78cd 100644
--- a/tests/ra_aid/test_agent_context.py
+++ b/tests/ra_aid/test_agent_context.py
@@ -120,7 +120,7 @@ class TestExitPropagation:
"""Test cases for the agent_should_exit flag propagation."""
def test_mark_should_exit_propagation(self):
- """Test that mark_should_exit propagates to parent contexts."""
+ """Test that mark_should_exit propagates to parent contexts when requested."""
parent = AgentContext()
child = AgentContext(parent_context=parent)
@@ -128,15 +128,26 @@ class TestExitPropagation:
assert parent.agent_should_exit is False
assert child.agent_should_exit is False
- # Mark the child context as should exit
- child.mark_should_exit()
+ # Test with explicit propagation to all parents
+ child.mark_should_exit(propagation_depth=None)
# Both child and parent should now have agent_should_exit as True
assert child.agent_should_exit is True
assert parent.agent_should_exit is True
+
+ # Reset for next test
+ parent.agent_should_exit = False
+ child.agent_should_exit = False
+
+ # Test default behavior (no propagation)
+ child.mark_should_exit()
+
+ # Only child should have agent_should_exit as True
+ assert child.agent_should_exit is True
+ assert parent.agent_should_exit is False
def test_nested_should_exit_propagation(self):
- """Test that mark_should_exit propagates through multiple levels of parent contexts."""
+ """Test that mark_should_exit propagates through multiple levels of parent contexts when requested."""
grandparent = AgentContext()
parent = AgentContext(parent_context=grandparent)
child = AgentContext(parent_context=parent)
@@ -146,28 +157,55 @@ class TestExitPropagation:
assert parent.agent_should_exit is False
assert child.agent_should_exit is False
- # Mark the child context as should exit
- child.mark_should_exit()
+ # Test with explicit propagation to all parents
+ child.mark_should_exit(propagation_depth=None)
# All contexts should now have agent_should_exit as True
assert child.agent_should_exit is True
assert parent.agent_should_exit is True
assert grandparent.agent_should_exit is True
+
+ # Reset for next test
+ grandparent.agent_should_exit = False
+ parent.agent_should_exit = False
+ child.agent_should_exit = False
+
+ # Test default behavior (no propagation)
+ child.mark_should_exit()
+
+ # Only child should have agent_should_exit as True
+ assert child.agent_should_exit is True
+ assert parent.agent_should_exit is False
+ assert grandparent.agent_should_exit is False
def test_context_manager_should_exit_propagation(self):
- """Test that mark_should_exit propagates when using context managers."""
+ """Test that mark_should_exit propagates when using context managers when requested."""
with agent_context() as outer:
with agent_context() as inner:
# Initially both contexts should have agent_should_exit as False
assert outer.agent_should_exit is False
assert inner.agent_should_exit is False
- # Mark the inner context as should exit
- inner.mark_should_exit()
+ # Test with explicit propagation to all parents
+ inner.mark_should_exit(propagation_depth=None)
# Both inner and outer should now have agent_should_exit as True
assert inner.agent_should_exit is True
assert outer.agent_should_exit is True
+
+ # Test default behavior (no propagation)
+ with agent_context() as outer:
+ with agent_context() as inner:
+ # Initially both contexts should have agent_should_exit as False
+ assert outer.agent_should_exit is False
+ assert inner.agent_should_exit is False
+
+ # Mark the inner context as should exit with default propagation
+ inner.mark_should_exit()
+
+ # Only inner should have agent_should_exit as True
+ assert inner.agent_should_exit is True
+ assert outer.agent_should_exit is False
class TestCrashPropagation:
@@ -344,10 +382,24 @@ class TestUtilityFunctions:
# Initially both contexts should have agent_should_exit as False
assert should_exit() is False
- # Mark the current context (inner) as should exit
- mark_should_exit()
+ # Test with explicit propagation to all parents
+ mark_should_exit(propagation_depth=None)
# Both inner and outer should now have agent_should_exit as True
assert should_exit() is True
assert inner.agent_should_exit is True
assert outer.agent_should_exit is True
+
+ # Test default behavior (no propagation)
+ with agent_context() as outer:
+ with agent_context() as inner:
+ # Initially both contexts should have agent_should_exit as False
+ assert should_exit() is False
+
+ # Mark the current context (inner) as should exit with default propagation
+ mark_should_exit()
+
+ # Only inner should have agent_should_exit as True
+ assert should_exit() is True
+ assert inner.agent_should_exit is True
+ assert outer.agent_should_exit is False
diff --git a/tests/ra_aid/test_agent_should_exit_ciayn.py b/tests/ra_aid/test_agent_should_exit_ciayn.py
index 150b184..51bf536 100644
--- a/tests/ra_aid/test_agent_should_exit_ciayn.py
+++ b/tests/ra_aid/test_agent_should_exit_ciayn.py
@@ -27,8 +27,8 @@ def test_ciayn_agent_stream_respects_should_exit():
# Test stream exits when should_exit is set
with agent_context() as ctx:
- # Set should_exit
- mark_should_exit()
+ # Set should_exit with propagation to all parents
+ mark_should_exit(propagation_depth=None)
# Verify should_exit is set
assert should_exit()
@@ -80,8 +80,8 @@ def test_ciayn_agent_bundled_tools_respects_should_exit():
# Set up test messages
messages = {"messages": [HumanMessage(content="test")]}
- # Set should_exit before calling stream
- mark_should_exit()
+ # Set should_exit with propagation to all parents
+ mark_should_exit(propagation_depth=None)
# Call stream
generator = agent.stream(messages)
@@ -128,7 +128,7 @@ def test_ciayn_agent_single_tool_respects_should_exit():
def execute_and_exit(*args, **kwargs):
# Set should_exit flag
with agent_context() as ctx:
- mark_should_exit()
+ mark_should_exit(propagation_depth=None)
return "Tool executed"
# Override the mock tool to set should_exit
@@ -176,8 +176,8 @@ def test_ciayn_agent_execute_tool_respects_should_exit():
# Test _execute_tool exits when should_exit is set
with agent_context() as ctx:
- # Set should_exit
- mark_should_exit()
+ # Set should_exit with propagation to all parents
+ mark_should_exit(propagation_depth=None)
# Call _execute_tool
message = HumanMessage(content="test_tool()")