should exit propagation
This commit is contained in:
parent
77856bfa0c
commit
9e9c3ad3d2
|
|
@ -56,16 +56,28 @@ class AgentContext:
|
|||
self.plan_completed = False
|
||||
self.completion_message = ""
|
||||
|
||||
def mark_should_exit(self) -> None:
|
||||
def mark_should_exit(self, propagation_depth: Optional[int] = 0) -> None:
|
||||
"""Mark that the agent should exit execution.
|
||||
|
||||
This propagates the exit state to all parent contexts.
|
||||
Args:
|
||||
propagation_depth: How far up the context hierarchy to propagate the flag.
|
||||
None: Propagate to all parent contexts
|
||||
0 (default): Only mark the current context
|
||||
1: Mark the current context and its immediate parent
|
||||
2+: Propagate up the specified number of levels
|
||||
"""
|
||||
self.agent_should_exit = True
|
||||
|
||||
# Propagate to parent context if it exists
|
||||
if self.parent:
|
||||
self.parent.mark_should_exit()
|
||||
# Propagate to parent context based on propagation_depth
|
||||
if propagation_depth is None:
|
||||
# Maintain current behavior of unlimited propagation
|
||||
if self.parent:
|
||||
self.parent.mark_should_exit(propagation_depth)
|
||||
elif propagation_depth > 0:
|
||||
# Propagate to parent with decremented depth
|
||||
if self.parent:
|
||||
self.parent.mark_should_exit(propagation_depth - 1)
|
||||
# If propagation_depth is 0, don't propagate to parent
|
||||
|
||||
def mark_agent_crashed(self, message: str) -> None:
|
||||
"""Mark the agent as crashed with the given message.
|
||||
|
|
@ -216,11 +228,19 @@ def should_exit() -> bool:
|
|||
return context.agent_should_exit if context else False
|
||||
|
||||
|
||||
def mark_should_exit() -> None:
|
||||
"""Mark that the agent should exit execution."""
|
||||
def mark_should_exit(propagation_depth: Optional[int] = 0) -> None:
|
||||
"""Mark that the agent should exit execution.
|
||||
|
||||
Args:
|
||||
propagation_depth: How far up the context hierarchy to propagate the flag.
|
||||
None: Propagate to all parent contexts
|
||||
0 (default): Only mark the current context
|
||||
1: Mark the current context and its immediate parent
|
||||
2+: Propagate up the specified number of levels
|
||||
"""
|
||||
context = get_current_context()
|
||||
if context:
|
||||
context.mark_should_exit()
|
||||
context.mark_should_exit(propagation_depth)
|
||||
|
||||
|
||||
def is_crashed() -> bool:
|
||||
|
|
|
|||
|
|
@ -14,12 +14,6 @@ from ra_aid.prompts.web_research_prompts import WEB_RESEARCH_PROMPT_SECTION_RESE
|
|||
|
||||
RESEARCH_COMMON_PROMPT_HEADER = """Current Date: {current_date}
|
||||
|
||||
<user query>
|
||||
{base_task}
|
||||
</user query>
|
||||
|
||||
KEEP IT SIMPLE
|
||||
|
||||
<previous research>
|
||||
<key facts>
|
||||
{key_facts}
|
||||
|
|
@ -171,6 +165,11 @@ Decision on Implementation
|
|||
If this is a top-level README.md or docs folder, start there.
|
||||
|
||||
If the user explicitly requested implementation, that means you should first perform all the background research for that task, then call request_implementation where the implementation will be carried out.
|
||||
<user query>
|
||||
{base_task}
|
||||
</user query>
|
||||
|
||||
KEEP IT SIMPLE
|
||||
|
||||
NEVER ANNOUNCE WHAT YOU ARE DOING, JUST DO IT!
|
||||
|
||||
|
|
@ -189,6 +188,12 @@ You have been spawned by a higher level research agent, so only spawn more resea
|
|||
|
||||
When you emit research notes, keep it extremely concise and relevant only to the specific research subquery you've been assigned.
|
||||
|
||||
<user query>
|
||||
{base_task}
|
||||
</user query>
|
||||
|
||||
KEEP IT SIMPLE
|
||||
|
||||
NEVER ANNOUNCE WHAT YOU ARE DOING, JUST DO IT!
|
||||
"""
|
||||
)
|
||||
|
|
@ -29,16 +29,27 @@ class TestAgentShouldExit:
|
|||
assert ctx.agent_should_exit is True
|
||||
|
||||
def test_propagation_to_parent_context(self):
|
||||
"""Test that mark_should_exit propagates to parent contexts."""
|
||||
"""Test that mark_should_exit propagates to parent contexts when specified."""
|
||||
parent = AgentContext()
|
||||
child = AgentContext(parent_context=parent)
|
||||
|
||||
# Mark child as should exit
|
||||
child.mark_should_exit()
|
||||
# Mark child as should exit with propagation to all parents
|
||||
child.mark_should_exit(propagation_depth=None)
|
||||
|
||||
# Verify both child and parent are marked
|
||||
assert child.agent_should_exit is True
|
||||
assert parent.agent_should_exit is True
|
||||
|
||||
# Reset for the next test
|
||||
parent.agent_should_exit = False
|
||||
child.agent_should_exit = False
|
||||
|
||||
# Test default behavior (propagation_depth=0)
|
||||
child.mark_should_exit()
|
||||
|
||||
# Verify only child is marked
|
||||
assert child.agent_should_exit is True
|
||||
assert parent.agent_should_exit is False
|
||||
|
||||
def test_nested_context_manager_propagation(self):
|
||||
"""Test propagation with nested context managers."""
|
||||
|
|
@ -48,9 +59,139 @@ class TestAgentShouldExit:
|
|||
assert outer.agent_should_exit is False
|
||||
assert inner.agent_should_exit is False
|
||||
|
||||
# Mark inner as should exit
|
||||
inner.mark_should_exit()
|
||||
# Mark inner as should exit with explicit propagation to all parents
|
||||
inner.mark_should_exit(propagation_depth=None)
|
||||
|
||||
# Both should now be True
|
||||
assert inner.agent_should_exit is True
|
||||
assert outer.agent_should_exit is True
|
||||
|
||||
# Test default behavior (propagation_depth=0)
|
||||
with agent_context() as outer:
|
||||
with agent_context() as inner:
|
||||
# Initially both should be False
|
||||
assert outer.agent_should_exit is False
|
||||
assert inner.agent_should_exit is False
|
||||
|
||||
# Mark inner as should exit with default propagation
|
||||
inner.mark_should_exit()
|
||||
|
||||
# Only inner should be True
|
||||
assert inner.agent_should_exit is True
|
||||
assert outer.agent_should_exit is False
|
||||
|
||||
def test_mark_should_exit_propagation_depth(self):
|
||||
"""Test that mark_should_exit respects propagation depth."""
|
||||
# Create a hierarchy of contexts: ctx1 (root) -> ctx2 -> ctx3
|
||||
ctx1 = AgentContext()
|
||||
ctx2 = AgentContext(parent_context=ctx1)
|
||||
ctx3 = AgentContext(parent_context=ctx2)
|
||||
|
||||
# Test case 1: propagation_depth=0 (only marks current context)
|
||||
ctx3.mark_should_exit(propagation_depth=0)
|
||||
assert ctx3.agent_should_exit is True
|
||||
assert ctx2.agent_should_exit is False
|
||||
assert ctx1.agent_should_exit is False
|
||||
|
||||
# Reset all contexts
|
||||
ctx1.agent_should_exit = False
|
||||
ctx2.agent_should_exit = False
|
||||
ctx3.agent_should_exit = False
|
||||
|
||||
# Test case 2: propagation_depth=1 (marks current context and immediate parent)
|
||||
ctx3.mark_should_exit(propagation_depth=1)
|
||||
assert ctx3.agent_should_exit is True
|
||||
assert ctx2.agent_should_exit is True
|
||||
assert ctx1.agent_should_exit is False
|
||||
|
||||
# Reset all contexts
|
||||
ctx1.agent_should_exit = False
|
||||
ctx2.agent_should_exit = False
|
||||
ctx3.agent_should_exit = False
|
||||
|
||||
# Test case 3: propagation_depth=2 (marks current context, parent, and grandparent)
|
||||
ctx3.mark_should_exit(propagation_depth=2)
|
||||
assert ctx3.agent_should_exit is True
|
||||
assert ctx2.agent_should_exit is True
|
||||
assert ctx1.agent_should_exit is True
|
||||
|
||||
# Reset all contexts
|
||||
ctx1.agent_should_exit = False
|
||||
ctx2.agent_should_exit = False
|
||||
ctx3.agent_should_exit = False
|
||||
|
||||
# Test case 4: propagation_depth=None (marks all contexts)
|
||||
ctx3.mark_should_exit(propagation_depth=None)
|
||||
assert ctx3.agent_should_exit is True
|
||||
assert ctx2.agent_should_exit is True
|
||||
assert ctx1.agent_should_exit is True
|
||||
|
||||
# Reset all contexts
|
||||
ctx1.agent_should_exit = False
|
||||
ctx2.agent_should_exit = False
|
||||
ctx3.agent_should_exit = False
|
||||
|
||||
# Test case 5: default behavior (propagation_depth=0)
|
||||
ctx3.mark_should_exit() # Default is now 0
|
||||
assert ctx3.agent_should_exit is True
|
||||
assert ctx2.agent_should_exit is False
|
||||
assert ctx1.agent_should_exit is False
|
||||
|
||||
def test_helper_mark_should_exit_propagation_depth(self):
|
||||
"""Test that helper mark_should_exit function respects propagation depth."""
|
||||
# Create a hierarchy of contexts
|
||||
ctx1 = AgentContext()
|
||||
ctx2 = AgentContext(parent_context=ctx1)
|
||||
|
||||
# Test with agent_context to set the current context
|
||||
with agent_context(ctx2) as current_ctx:
|
||||
# Test case 1: propagation_depth=0 (only marks current context)
|
||||
mark_should_exit(propagation_depth=0)
|
||||
assert current_ctx.agent_should_exit is True
|
||||
assert ctx2.agent_should_exit is False # The context manager creates a new context
|
||||
assert ctx1.agent_should_exit is False
|
||||
|
||||
# Reset for the next test
|
||||
ctx1.agent_should_exit = False
|
||||
ctx2.agent_should_exit = False
|
||||
|
||||
with agent_context(ctx2) as current_ctx:
|
||||
# Test case 2: propagation_depth=1 (marks current context and immediate parent)
|
||||
mark_should_exit(propagation_depth=1)
|
||||
assert current_ctx.agent_should_exit is True
|
||||
# The current_ctx's parent is ctx2, so it should be marked
|
||||
assert ctx2.agent_should_exit is True
|
||||
assert ctx1.agent_should_exit is False
|
||||
|
||||
# Reset for the next test
|
||||
ctx1.agent_should_exit = False
|
||||
ctx2.agent_should_exit = False
|
||||
|
||||
with agent_context(ctx2) as current_ctx:
|
||||
# Test case 3: propagation_depth=2 (marks current context, parent, and grandparent)
|
||||
mark_should_exit(propagation_depth=2)
|
||||
assert current_ctx.agent_should_exit is True
|
||||
assert ctx2.agent_should_exit is True
|
||||
assert ctx1.agent_should_exit is True
|
||||
|
||||
# Reset for the next test
|
||||
ctx1.agent_should_exit = False
|
||||
ctx2.agent_should_exit = False
|
||||
|
||||
with agent_context(ctx2) as current_ctx:
|
||||
# Test case 4: propagation_depth=None (marks all contexts)
|
||||
mark_should_exit(propagation_depth=None)
|
||||
assert current_ctx.agent_should_exit is True
|
||||
assert ctx2.agent_should_exit is True
|
||||
assert ctx1.agent_should_exit is True
|
||||
|
||||
# Reset for the next test
|
||||
ctx1.agent_should_exit = False
|
||||
ctx2.agent_should_exit = False
|
||||
|
||||
with agent_context(ctx2) as current_ctx:
|
||||
# Test case 5: default behavior (propagation_depth=0)
|
||||
mark_should_exit() # Default is now 0
|
||||
assert current_ctx.agent_should_exit is True
|
||||
assert ctx2.agent_should_exit is False
|
||||
assert ctx1.agent_should_exit is False
|
||||
|
|
|
|||
|
|
@ -277,7 +277,7 @@ def plan_implementation_completed(message: str) -> str:
|
|||
Args:
|
||||
message: Message explaining how the implementation plan was completed
|
||||
"""
|
||||
mark_should_exit()
|
||||
mark_should_exit(propagation_depth=1)
|
||||
mark_plan_completed(message)
|
||||
console.print(Panel(Markdown(message), title="✅ Plan Executed"))
|
||||
log_work_event(f"Completed implementation:\n\n{message}")
|
||||
|
|
|
|||
|
|
@ -120,7 +120,7 @@ class TestExitPropagation:
|
|||
"""Test cases for the agent_should_exit flag propagation."""
|
||||
|
||||
def test_mark_should_exit_propagation(self):
|
||||
"""Test that mark_should_exit propagates to parent contexts."""
|
||||
"""Test that mark_should_exit propagates to parent contexts when requested."""
|
||||
parent = AgentContext()
|
||||
child = AgentContext(parent_context=parent)
|
||||
|
||||
|
|
@ -128,15 +128,26 @@ class TestExitPropagation:
|
|||
assert parent.agent_should_exit is False
|
||||
assert child.agent_should_exit is False
|
||||
|
||||
# Mark the child context as should exit
|
||||
child.mark_should_exit()
|
||||
# Test with explicit propagation to all parents
|
||||
child.mark_should_exit(propagation_depth=None)
|
||||
|
||||
# Both child and parent should now have agent_should_exit as True
|
||||
assert child.agent_should_exit is True
|
||||
assert parent.agent_should_exit is True
|
||||
|
||||
# Reset for next test
|
||||
parent.agent_should_exit = False
|
||||
child.agent_should_exit = False
|
||||
|
||||
# Test default behavior (no propagation)
|
||||
child.mark_should_exit()
|
||||
|
||||
# Only child should have agent_should_exit as True
|
||||
assert child.agent_should_exit is True
|
||||
assert parent.agent_should_exit is False
|
||||
|
||||
def test_nested_should_exit_propagation(self):
|
||||
"""Test that mark_should_exit propagates through multiple levels of parent contexts."""
|
||||
"""Test that mark_should_exit propagates through multiple levels of parent contexts when requested."""
|
||||
grandparent = AgentContext()
|
||||
parent = AgentContext(parent_context=grandparent)
|
||||
child = AgentContext(parent_context=parent)
|
||||
|
|
@ -146,28 +157,55 @@ class TestExitPropagation:
|
|||
assert parent.agent_should_exit is False
|
||||
assert child.agent_should_exit is False
|
||||
|
||||
# Mark the child context as should exit
|
||||
child.mark_should_exit()
|
||||
# Test with explicit propagation to all parents
|
||||
child.mark_should_exit(propagation_depth=None)
|
||||
|
||||
# All contexts should now have agent_should_exit as True
|
||||
assert child.agent_should_exit is True
|
||||
assert parent.agent_should_exit is True
|
||||
assert grandparent.agent_should_exit is True
|
||||
|
||||
# Reset for next test
|
||||
grandparent.agent_should_exit = False
|
||||
parent.agent_should_exit = False
|
||||
child.agent_should_exit = False
|
||||
|
||||
# Test default behavior (no propagation)
|
||||
child.mark_should_exit()
|
||||
|
||||
# Only child should have agent_should_exit as True
|
||||
assert child.agent_should_exit is True
|
||||
assert parent.agent_should_exit is False
|
||||
assert grandparent.agent_should_exit is False
|
||||
|
||||
def test_context_manager_should_exit_propagation(self):
|
||||
"""Test that mark_should_exit propagates when using context managers."""
|
||||
"""Test that mark_should_exit propagates when using context managers when requested."""
|
||||
with agent_context() as outer:
|
||||
with agent_context() as inner:
|
||||
# Initially both contexts should have agent_should_exit as False
|
||||
assert outer.agent_should_exit is False
|
||||
assert inner.agent_should_exit is False
|
||||
|
||||
# Mark the inner context as should exit
|
||||
inner.mark_should_exit()
|
||||
# Test with explicit propagation to all parents
|
||||
inner.mark_should_exit(propagation_depth=None)
|
||||
|
||||
# Both inner and outer should now have agent_should_exit as True
|
||||
assert inner.agent_should_exit is True
|
||||
assert outer.agent_should_exit is True
|
||||
|
||||
# Test default behavior (no propagation)
|
||||
with agent_context() as outer:
|
||||
with agent_context() as inner:
|
||||
# Initially both contexts should have agent_should_exit as False
|
||||
assert outer.agent_should_exit is False
|
||||
assert inner.agent_should_exit is False
|
||||
|
||||
# Mark the inner context as should exit with default propagation
|
||||
inner.mark_should_exit()
|
||||
|
||||
# Only inner should have agent_should_exit as True
|
||||
assert inner.agent_should_exit is True
|
||||
assert outer.agent_should_exit is False
|
||||
|
||||
|
||||
class TestCrashPropagation:
|
||||
|
|
@ -344,10 +382,24 @@ class TestUtilityFunctions:
|
|||
# Initially both contexts should have agent_should_exit as False
|
||||
assert should_exit() is False
|
||||
|
||||
# Mark the current context (inner) as should exit
|
||||
mark_should_exit()
|
||||
# Test with explicit propagation to all parents
|
||||
mark_should_exit(propagation_depth=None)
|
||||
|
||||
# Both inner and outer should now have agent_should_exit as True
|
||||
assert should_exit() is True
|
||||
assert inner.agent_should_exit is True
|
||||
assert outer.agent_should_exit is True
|
||||
|
||||
# Test default behavior (no propagation)
|
||||
with agent_context() as outer:
|
||||
with agent_context() as inner:
|
||||
# Initially both contexts should have agent_should_exit as False
|
||||
assert should_exit() is False
|
||||
|
||||
# Mark the current context (inner) as should exit with default propagation
|
||||
mark_should_exit()
|
||||
|
||||
# Only inner should have agent_should_exit as True
|
||||
assert should_exit() is True
|
||||
assert inner.agent_should_exit is True
|
||||
assert outer.agent_should_exit is False
|
||||
|
|
|
|||
|
|
@ -27,8 +27,8 @@ def test_ciayn_agent_stream_respects_should_exit():
|
|||
|
||||
# Test stream exits when should_exit is set
|
||||
with agent_context() as ctx:
|
||||
# Set should_exit
|
||||
mark_should_exit()
|
||||
# Set should_exit with propagation to all parents
|
||||
mark_should_exit(propagation_depth=None)
|
||||
|
||||
# Verify should_exit is set
|
||||
assert should_exit()
|
||||
|
|
@ -80,8 +80,8 @@ def test_ciayn_agent_bundled_tools_respects_should_exit():
|
|||
# Set up test messages
|
||||
messages = {"messages": [HumanMessage(content="test")]}
|
||||
|
||||
# Set should_exit before calling stream
|
||||
mark_should_exit()
|
||||
# Set should_exit with propagation to all parents
|
||||
mark_should_exit(propagation_depth=None)
|
||||
|
||||
# Call stream
|
||||
generator = agent.stream(messages)
|
||||
|
|
@ -128,7 +128,7 @@ def test_ciayn_agent_single_tool_respects_should_exit():
|
|||
def execute_and_exit(*args, **kwargs):
|
||||
# Set should_exit flag
|
||||
with agent_context() as ctx:
|
||||
mark_should_exit()
|
||||
mark_should_exit(propagation_depth=None)
|
||||
return "Tool executed"
|
||||
|
||||
# Override the mock tool to set should_exit
|
||||
|
|
@ -176,8 +176,8 @@ def test_ciayn_agent_execute_tool_respects_should_exit():
|
|||
|
||||
# Test _execute_tool exits when should_exit is set
|
||||
with agent_context() as ctx:
|
||||
# Set should_exit
|
||||
mark_should_exit()
|
||||
# Set should_exit with propagation to all parents
|
||||
mark_should_exit(propagation_depth=None)
|
||||
|
||||
# Call _execute_tool
|
||||
message = HumanMessage(content="test_tool()")
|
||||
|
|
|
|||
Loading…
Reference in New Issue