improve agent crash detection
This commit is contained in:
parent
9202cf0d6d
commit
346e22b5cb
|
|
@ -25,6 +25,8 @@ class AgentContext:
|
|||
self.plan_completed = False
|
||||
self.completion_message = ""
|
||||
self.agent_should_exit = False
|
||||
self.agent_has_crashed = False
|
||||
self.agent_crashed_message = None
|
||||
|
||||
# Note: Completion flags (task_completed, plan_completed, completion_message,
|
||||
# agent_should_exit) are no longer inherited from parent contexts
|
||||
|
|
@ -64,6 +66,25 @@ class AgentContext:
|
|||
# Propagate to parent context if it exists
|
||||
if self.parent:
|
||||
self.parent.mark_should_exit()
|
||||
|
||||
def mark_agent_crashed(self, message: str) -> None:
|
||||
"""Mark the agent as crashed with the given message.
|
||||
|
||||
Unlike exit state, crash state does not propagate to parent contexts.
|
||||
|
||||
Args:
|
||||
message: Error message explaining the crash
|
||||
"""
|
||||
self.agent_has_crashed = True
|
||||
self.agent_crashed_message = message
|
||||
|
||||
def is_crashed(self) -> bool:
|
||||
"""Check if the agent has crashed.
|
||||
|
||||
Returns:
|
||||
True if the agent has crashed, False otherwise
|
||||
"""
|
||||
return self.agent_has_crashed
|
||||
|
||||
@property
|
||||
def is_completed(self) -> bool:
|
||||
|
|
@ -176,4 +197,35 @@ def mark_should_exit() -> None:
|
|||
"""Mark that the agent should exit execution."""
|
||||
context = get_current_context()
|
||||
if context:
|
||||
context.mark_should_exit()
|
||||
context.mark_should_exit()
|
||||
|
||||
|
||||
def is_crashed() -> bool:
|
||||
"""Check if the current agent has crashed.
|
||||
|
||||
Returns:
|
||||
True if the current agent has crashed, False otherwise
|
||||
"""
|
||||
context = get_current_context()
|
||||
return context.is_crashed() if context else False
|
||||
|
||||
|
||||
def mark_agent_crashed(message: str) -> None:
|
||||
"""Mark the current agent as crashed with the given message.
|
||||
|
||||
Args:
|
||||
message: Error message explaining the crash
|
||||
"""
|
||||
context = get_current_context()
|
||||
if context:
|
||||
context.mark_agent_crashed(message)
|
||||
|
||||
|
||||
def get_crash_message() -> Optional[str]:
|
||||
"""Get the crash message from the current context.
|
||||
|
||||
Returns:
|
||||
The crash message or None if the agent has not crashed
|
||||
"""
|
||||
context = get_current_context()
|
||||
return context.agent_crashed_message if context and context.is_crashed() else None
|
||||
|
|
@ -933,6 +933,14 @@ def run_agent_with_retry(
|
|||
for attempt in range(max_retries):
|
||||
logger.debug("Attempt %d/%d", attempt + 1, max_retries)
|
||||
check_interrupt()
|
||||
|
||||
# Check if the agent has crashed before attempting to run it
|
||||
from ra_aid.agent_context import is_crashed, get_crash_message
|
||||
if is_crashed():
|
||||
crash_message = get_crash_message()
|
||||
logger.error("Agent has crashed: %s", crash_message)
|
||||
return f"Agent has crashed: {crash_message}"
|
||||
|
||||
try:
|
||||
_run_agent_stream(agent, msg_list, config)
|
||||
if fallback_handler:
|
||||
|
|
@ -950,6 +958,15 @@ def run_agent_with_retry(
|
|||
logger.debug("Agent run completed successfully")
|
||||
return "Agent run completed successfully"
|
||||
except ToolExecutionError as e:
|
||||
# Check if this is a BadRequestError (HTTP 400) which is unretryable
|
||||
error_str = str(e).lower()
|
||||
if "400" in error_str or "bad request" in error_str:
|
||||
from ra_aid.agent_context import mark_agent_crashed
|
||||
crash_message = f"Unretryable error: {str(e)}"
|
||||
mark_agent_crashed(crash_message)
|
||||
logger.error("Agent has crashed: %s", crash_message)
|
||||
return f"Agent has crashed: {crash_message}"
|
||||
|
||||
_handle_fallback_response(e, fallback_handler, agent, msg_list)
|
||||
continue
|
||||
except FallbackToolExecutionError as e:
|
||||
|
|
@ -965,6 +982,15 @@ def run_agent_with_retry(
|
|||
APIError,
|
||||
ValueError,
|
||||
) as e:
|
||||
# Check if this is a BadRequestError (HTTP 400) which is unretryable
|
||||
error_str = str(e).lower()
|
||||
if ("400" in error_str or "bad request" in error_str) and isinstance(e, APIError):
|
||||
from ra_aid.agent_context import mark_agent_crashed
|
||||
crash_message = f"Unretryable API error: {str(e)}"
|
||||
mark_agent_crashed(crash_message)
|
||||
logger.error("Agent has crashed: %s", crash_message)
|
||||
return f"Agent has crashed: {crash_message}"
|
||||
|
||||
_handle_api_error(e, attempt, max_retries, base_delay)
|
||||
finally:
|
||||
_decrement_agent_depth()
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from typing import Any, Dict, List, Union
|
|||
from langchain_core.tools import tool
|
||||
from rich.console import Console
|
||||
|
||||
from ra_aid.agent_context import get_completion_message, reset_completion_flags
|
||||
from ra_aid.agent_context import get_completion_message, get_crash_message, is_crashed, reset_completion_flags
|
||||
from ra_aid.console.formatting import print_error
|
||||
from ra_aid.exceptions import AgentInterrupt
|
||||
from ra_aid.tools.memory import _global_memory
|
||||
|
|
@ -301,13 +301,19 @@ def request_task_implementation(task_spec: str) -> str:
|
|||
# Clear completion state
|
||||
reset_completion_flags()
|
||||
|
||||
# Check if the agent has crashed
|
||||
agent_crashed = is_crashed()
|
||||
crash_message = get_crash_message() if agent_crashed else None
|
||||
|
||||
response_data = {
|
||||
"key_facts": get_memory_value("key_facts"),
|
||||
"related_files": get_related_files(),
|
||||
"key_snippets": get_memory_value("key_snippets"),
|
||||
"completion_message": completion_message,
|
||||
"success": success,
|
||||
"success": success and not agent_crashed,
|
||||
"reason": reason,
|
||||
"agent_crashed": agent_crashed,
|
||||
"crash_message": crash_message,
|
||||
}
|
||||
if work_log is not None:
|
||||
response_data["work_log"] = work_log
|
||||
|
|
@ -320,6 +326,10 @@ def request_task_implementation(task_spec: str) -> str:
|
|||
if response_data.get("completion_message"):
|
||||
markdown_parts.append(f"\n## Completion Message\n\n{response_data['completion_message']}")
|
||||
|
||||
# Add crash information if applicable
|
||||
if response_data.get("agent_crashed"):
|
||||
markdown_parts.append(f"\n## ⚠️ Agent Crashed ⚠️\n\n**Error:** {response_data.get('crash_message', 'Unknown error')}")
|
||||
|
||||
# Add success status
|
||||
status = "Success" if response_data.get("success", False) else "Failed"
|
||||
reason_text = f": {response_data.get('reason')}" if response_data.get("reason") else ""
|
||||
|
|
@ -401,13 +411,19 @@ def request_implementation(task_spec: str) -> str:
|
|||
# Clear completion state
|
||||
reset_completion_flags()
|
||||
|
||||
# Check if the agent has crashed
|
||||
agent_crashed = is_crashed()
|
||||
crash_message = get_crash_message() if agent_crashed else None
|
||||
|
||||
response_data = {
|
||||
"completion_message": completion_message,
|
||||
"key_facts": get_memory_value("key_facts"),
|
||||
"related_files": get_related_files(),
|
||||
"key_snippets": get_memory_value("key_snippets"),
|
||||
"success": success,
|
||||
"success": success and not agent_crashed,
|
||||
"reason": reason,
|
||||
"agent_crashed": agent_crashed,
|
||||
"crash_message": crash_message,
|
||||
}
|
||||
if work_log is not None:
|
||||
response_data["work_log"] = work_log
|
||||
|
|
@ -420,6 +436,10 @@ def request_implementation(task_spec: str) -> str:
|
|||
if response_data.get("completion_message"):
|
||||
markdown_parts.append(f"\n## Completion Message\n\n{response_data['completion_message']}")
|
||||
|
||||
# Add crash information if applicable
|
||||
if response_data.get("agent_crashed"):
|
||||
markdown_parts.append(f"\n## ⚠️ Agent Crashed ⚠️\n\n**Error:** {response_data.get('crash_message', 'Unknown error')}")
|
||||
|
||||
# Add success status
|
||||
status = "Success" if response_data.get("success", False) else "Failed"
|
||||
reason_text = f": {response_data.get('reason')}" if response_data.get("reason") else ""
|
||||
|
|
@ -446,4 +466,4 @@ def request_implementation(task_spec: str) -> str:
|
|||
# Join all parts into a single markdown string
|
||||
markdown_output = "".join(markdown_parts)
|
||||
|
||||
return markdown_output
|
||||
return markdown_output
|
||||
|
|
@ -15,6 +15,9 @@ from ra_aid.agent_context import (
|
|||
get_completion_message,
|
||||
mark_should_exit,
|
||||
should_exit,
|
||||
mark_agent_crashed,
|
||||
is_crashed,
|
||||
get_crash_message,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -169,6 +172,84 @@ class TestExitPropagation:
|
|||
assert outer.agent_should_exit is True
|
||||
|
||||
|
||||
class TestCrashPropagation:
|
||||
"""Test cases for the agent_has_crashed flag non-propagation."""
|
||||
|
||||
def test_mark_agent_crashed_no_propagation(self):
|
||||
"""Test that mark_agent_crashed does not propagate to parent contexts."""
|
||||
parent = AgentContext()
|
||||
child = AgentContext(parent_context=parent)
|
||||
|
||||
# Initially both contexts should have agent_has_crashed as False
|
||||
assert parent.is_crashed() is False
|
||||
assert child.is_crashed() is False
|
||||
|
||||
# Mark the child context as crashed
|
||||
child.mark_agent_crashed("Child crashed")
|
||||
|
||||
# Child should be crashed, but parent should not
|
||||
assert child.is_crashed() is True
|
||||
assert parent.is_crashed() is False
|
||||
assert child.agent_crashed_message == "Child crashed"
|
||||
assert parent.agent_crashed_message is None
|
||||
|
||||
def test_nested_crash_no_propagation(self):
|
||||
"""Test that crash states don't propagate through multiple levels of parent contexts."""
|
||||
grandparent = AgentContext()
|
||||
parent = AgentContext(parent_context=grandparent)
|
||||
child = AgentContext(parent_context=parent)
|
||||
|
||||
# Initially all contexts should have agent_has_crashed as False
|
||||
assert grandparent.is_crashed() is False
|
||||
assert parent.is_crashed() is False
|
||||
assert child.is_crashed() is False
|
||||
|
||||
# Mark the child context as crashed
|
||||
child.mark_agent_crashed("Child crashed")
|
||||
|
||||
# Only child should be crashed, parent and grandparent should not
|
||||
assert child.is_crashed() is True
|
||||
assert parent.is_crashed() is False
|
||||
assert grandparent.is_crashed() is False
|
||||
assert child.agent_crashed_message == "Child crashed"
|
||||
assert parent.agent_crashed_message is None
|
||||
assert grandparent.agent_crashed_message is None
|
||||
|
||||
def test_context_manager_crash_no_propagation(self):
|
||||
"""Test that crash state doesn't propagate when using context managers."""
|
||||
with agent_context() as outer:
|
||||
with agent_context() as inner:
|
||||
# Initially both contexts should have agent_has_crashed as False
|
||||
assert outer.is_crashed() is False
|
||||
assert inner.is_crashed() is False
|
||||
|
||||
# Mark the inner context as crashed
|
||||
inner.mark_agent_crashed("Inner crashed")
|
||||
|
||||
# Inner should be crashed, but outer should not
|
||||
assert inner.is_crashed() is True
|
||||
assert outer.is_crashed() is False
|
||||
assert inner.agent_crashed_message == "Inner crashed"
|
||||
assert outer.agent_crashed_message is None
|
||||
|
||||
def test_crash_state_not_inherited(self):
|
||||
"""Test that new child contexts don't inherit crash states from parent contexts."""
|
||||
parent = AgentContext()
|
||||
|
||||
# Mark the parent as crashed
|
||||
parent.mark_agent_crashed("Parent crashed")
|
||||
assert parent.is_crashed() is True
|
||||
|
||||
# Create a child context with the crashed parent as parent_context
|
||||
child = AgentContext(parent_context=parent)
|
||||
|
||||
# Child should not be crashed even though parent is
|
||||
assert parent.is_crashed() is True
|
||||
assert child.is_crashed() is False
|
||||
assert parent.agent_crashed_message == "Parent crashed"
|
||||
assert child.agent_crashed_message is None
|
||||
|
||||
|
||||
class TestThreadIsolation:
|
||||
"""Test thread isolation of context variables."""
|
||||
|
||||
|
|
|
|||
|
|
@ -426,3 +426,191 @@ def test_is_anthropic_claude():
|
|||
assert not is_anthropic_claude({"provider": "anthropic"}) # Missing model
|
||||
assert not is_anthropic_claude({"model": "claude-2"}) # Missing provider
|
||||
assert not is_anthropic_claude({"provider": "other", "model": "claude-2"}) # Wrong provider
|
||||
|
||||
|
||||
def test_run_agent_with_retry_checks_crash_status(monkeypatch):
|
||||
"""Test that run_agent_with_retry checks for crash status at the beginning of each iteration."""
|
||||
from ra_aid.agent_utils import run_agent_with_retry
|
||||
from ra_aid.agent_context import agent_context, mark_agent_crashed
|
||||
|
||||
# Setup mocks for dependencies to isolate our test
|
||||
dummy_agent = Mock()
|
||||
|
||||
# Track function calls
|
||||
mock_calls = {"run_agent_stream": 0}
|
||||
|
||||
def mock_run_agent_stream(*args, **kwargs):
|
||||
mock_calls["run_agent_stream"] += 1
|
||||
|
||||
def mock_setup_interrupt_handling():
|
||||
return None
|
||||
|
||||
def mock_restore_interrupt_handling(handler):
|
||||
pass
|
||||
|
||||
def mock_increment_agent_depth():
|
||||
pass
|
||||
|
||||
def mock_decrement_agent_depth():
|
||||
pass
|
||||
|
||||
def mock_is_crashed():
|
||||
return ctx.is_crashed() if ctx else False
|
||||
|
||||
def mock_get_crash_message():
|
||||
return ctx.agent_crashed_message if ctx and ctx.is_crashed() else None
|
||||
|
||||
# Apply mocks
|
||||
monkeypatch.setattr("ra_aid.agent_utils._run_agent_stream", mock_run_agent_stream)
|
||||
monkeypatch.setattr("ra_aid.agent_utils._setup_interrupt_handling", mock_setup_interrupt_handling)
|
||||
monkeypatch.setattr("ra_aid.agent_utils._restore_interrupt_handling", mock_restore_interrupt_handling)
|
||||
monkeypatch.setattr("ra_aid.agent_utils._increment_agent_depth", mock_increment_agent_depth)
|
||||
monkeypatch.setattr("ra_aid.agent_utils._decrement_agent_depth", mock_decrement_agent_depth)
|
||||
monkeypatch.setattr("ra_aid.agent_utils.check_interrupt", lambda: None)
|
||||
|
||||
# First, run without a crash - agent should be run
|
||||
with agent_context() as ctx:
|
||||
monkeypatch.setattr("ra_aid.agent_context.is_crashed", mock_is_crashed)
|
||||
monkeypatch.setattr("ra_aid.agent_context.get_crash_message", mock_get_crash_message)
|
||||
result = run_agent_with_retry(dummy_agent, "test prompt", {})
|
||||
assert mock_calls["run_agent_stream"] == 1
|
||||
|
||||
# Reset call counter
|
||||
mock_calls["run_agent_stream"] = 0
|
||||
|
||||
# Now run with a crash - agent should not be run
|
||||
with agent_context() as ctx:
|
||||
mark_agent_crashed("Test crash message")
|
||||
monkeypatch.setattr("ra_aid.agent_context.is_crashed", mock_is_crashed)
|
||||
monkeypatch.setattr("ra_aid.agent_context.get_crash_message", mock_get_crash_message)
|
||||
result = run_agent_with_retry(dummy_agent, "test prompt", {})
|
||||
# Verify _run_agent_stream was not called
|
||||
assert mock_calls["run_agent_stream"] == 0
|
||||
# Verify the result contains the crash message
|
||||
assert "Agent has crashed: Test crash message" in result
|
||||
|
||||
|
||||
def test_run_agent_with_retry_handles_badrequest_error(monkeypatch):
|
||||
"""Test that run_agent_with_retry properly handles BadRequestError as unretryable."""
|
||||
from ra_aid.agent_utils import run_agent_with_retry
|
||||
from ra_aid.exceptions import ToolExecutionError
|
||||
from ra_aid.agent_context import agent_context, is_crashed
|
||||
|
||||
# Setup mocks
|
||||
dummy_agent = Mock()
|
||||
|
||||
# Track function calls and simulate BadRequestError
|
||||
run_count = [0]
|
||||
|
||||
def mock_run_agent_stream(*args, **kwargs):
|
||||
run_count[0] += 1
|
||||
if run_count[0] == 1:
|
||||
# First call throws a 400 BadRequestError
|
||||
raise ToolExecutionError("400 Bad Request: Invalid input")
|
||||
# If it's called again, it should run normally
|
||||
|
||||
def mock_setup_interrupt_handling():
|
||||
return None
|
||||
|
||||
def mock_restore_interrupt_handling(handler):
|
||||
pass
|
||||
|
||||
def mock_increment_agent_depth():
|
||||
pass
|
||||
|
||||
def mock_decrement_agent_depth():
|
||||
pass
|
||||
|
||||
def mock_mark_agent_crashed(message):
|
||||
ctx.agent_has_crashed = True
|
||||
ctx.agent_crashed_message = message
|
||||
|
||||
def mock_is_crashed():
|
||||
return ctx.is_crashed() if ctx else False
|
||||
|
||||
# Apply mocks
|
||||
monkeypatch.setattr("ra_aid.agent_utils._run_agent_stream", mock_run_agent_stream)
|
||||
monkeypatch.setattr("ra_aid.agent_utils._setup_interrupt_handling", mock_setup_interrupt_handling)
|
||||
monkeypatch.setattr("ra_aid.agent_utils._restore_interrupt_handling", mock_restore_interrupt_handling)
|
||||
monkeypatch.setattr("ra_aid.agent_utils._increment_agent_depth", mock_increment_agent_depth)
|
||||
monkeypatch.setattr("ra_aid.agent_utils._decrement_agent_depth", mock_decrement_agent_depth)
|
||||
monkeypatch.setattr("ra_aid.agent_utils.check_interrupt", lambda: None)
|
||||
|
||||
with agent_context() as ctx:
|
||||
monkeypatch.setattr("ra_aid.agent_context.mark_agent_crashed", mock_mark_agent_crashed)
|
||||
monkeypatch.setattr("ra_aid.agent_context.is_crashed", mock_is_crashed)
|
||||
|
||||
result = run_agent_with_retry(dummy_agent, "test prompt", {})
|
||||
# Verify the agent was only run once and not retried
|
||||
assert run_count[0] == 1
|
||||
# Verify the result contains the crash message
|
||||
assert "Agent has crashed: Unretryable error" in result
|
||||
# Verify the agent is marked as crashed
|
||||
assert is_crashed()
|
||||
|
||||
|
||||
def test_run_agent_with_retry_handles_api_badrequest_error(monkeypatch):
|
||||
"""Test that run_agent_with_retry properly handles API BadRequestError as unretryable."""
|
||||
# Import APIError from anthropic module and patch it on the agent_utils module
|
||||
from anthropic import APIError as AnthropicAPIError
|
||||
from ra_aid.agent_utils import run_agent_with_retry
|
||||
from ra_aid.agent_context import agent_context, is_crashed
|
||||
|
||||
# Setup mocks
|
||||
dummy_agent = Mock()
|
||||
|
||||
# Track function calls and simulate BadRequestError
|
||||
run_count = [0]
|
||||
|
||||
# Create a mock APIError class that simulates Anthropic's APIError
|
||||
class MockAPIError(Exception):
|
||||
pass
|
||||
|
||||
def mock_run_agent_stream(*args, **kwargs):
|
||||
run_count[0] += 1
|
||||
if run_count[0] == 1:
|
||||
# First call throws a 400 Bad Request APIError
|
||||
mock_error = MockAPIError("400 Bad Request")
|
||||
mock_error.__class__.__name__ = "APIError" # Make it look like Anthropic's APIError
|
||||
raise mock_error
|
||||
# If it's called again, it should run normally
|
||||
|
||||
def mock_setup_interrupt_handling():
|
||||
return None
|
||||
|
||||
def mock_restore_interrupt_handling(handler):
|
||||
pass
|
||||
|
||||
def mock_increment_agent_depth():
|
||||
pass
|
||||
|
||||
def mock_decrement_agent_depth():
|
||||
pass
|
||||
|
||||
def mock_mark_agent_crashed(message):
|
||||
ctx.agent_has_crashed = True
|
||||
ctx.agent_crashed_message = message
|
||||
|
||||
def mock_is_crashed():
|
||||
return ctx.is_crashed() if ctx else False
|
||||
|
||||
# Apply mocks
|
||||
monkeypatch.setattr("ra_aid.agent_utils._run_agent_stream", mock_run_agent_stream)
|
||||
monkeypatch.setattr("ra_aid.agent_utils._setup_interrupt_handling", mock_setup_interrupt_handling)
|
||||
monkeypatch.setattr("ra_aid.agent_utils._restore_interrupt_handling", mock_restore_interrupt_handling)
|
||||
monkeypatch.setattr("ra_aid.agent_utils._increment_agent_depth", mock_increment_agent_depth)
|
||||
monkeypatch.setattr("ra_aid.agent_utils._decrement_agent_depth", mock_decrement_agent_depth)
|
||||
monkeypatch.setattr("ra_aid.agent_utils.check_interrupt", lambda: None)
|
||||
monkeypatch.setattr("ra_aid.agent_utils._handle_api_error", lambda *args: None)
|
||||
monkeypatch.setattr("ra_aid.agent_utils.APIError", MockAPIError)
|
||||
monkeypatch.setattr("ra_aid.agent_context.mark_agent_crashed", mock_mark_agent_crashed)
|
||||
monkeypatch.setattr("ra_aid.agent_context.is_crashed", mock_is_crashed)
|
||||
|
||||
with agent_context() as ctx:
|
||||
result = run_agent_with_retry(dummy_agent, "test prompt", {})
|
||||
# Verify the agent was only run once and not retried
|
||||
assert run_count[0] == 1
|
||||
# Verify the result contains the crash message
|
||||
assert "Agent has crashed: Unretryable API error" in result
|
||||
# Verify the agent is marked as crashed
|
||||
assert is_crashed()
|
||||
|
|
|
|||
|
|
@ -0,0 +1,84 @@
|
|||
"""Unit tests for crash propagation behavior in agent_context."""
|
||||
|
||||
import pytest
|
||||
|
||||
from ra_aid.agent_context import (
|
||||
AgentContext,
|
||||
agent_context,
|
||||
mark_agent_crashed,
|
||||
is_crashed,
|
||||
get_crash_message,
|
||||
)
|
||||
|
||||
|
||||
class TestCrashPropagation:
|
||||
"""Test cases for crash state propagation behavior."""
|
||||
|
||||
def test_mark_agent_crashed_no_propagation(self):
|
||||
"""Test that mark_agent_crashed does not propagate to parent contexts."""
|
||||
parent = AgentContext()
|
||||
child = AgentContext(parent_context=parent)
|
||||
|
||||
# Initially both contexts should have is_crashed as False
|
||||
assert parent.is_crashed() is False
|
||||
assert child.is_crashed() is False
|
||||
|
||||
# Mark the child context as crashed
|
||||
child.mark_agent_crashed("Child crashed")
|
||||
|
||||
# Child should be crashed but parent should not
|
||||
assert child.is_crashed() is True
|
||||
assert child.agent_crashed_message == "Child crashed"
|
||||
assert parent.is_crashed() is False
|
||||
assert parent.agent_crashed_message is None
|
||||
|
||||
def test_nested_crash_no_propagation(self):
|
||||
"""Test that crash state doesn't propagate through multiple levels of parent contexts."""
|
||||
grandparent = AgentContext()
|
||||
parent = AgentContext(parent_context=grandparent)
|
||||
child = AgentContext(parent_context=parent)
|
||||
|
||||
# Initially all contexts should have is_crashed as False
|
||||
assert grandparent.is_crashed() is False
|
||||
assert parent.is_crashed() is False
|
||||
assert child.is_crashed() is False
|
||||
|
||||
# Mark the child context as crashed
|
||||
child.mark_agent_crashed("Child crashed")
|
||||
|
||||
# Only child should be crashed
|
||||
assert child.is_crashed() is True
|
||||
assert parent.is_crashed() is False
|
||||
assert grandparent.is_crashed() is False
|
||||
|
||||
def test_context_manager_crash_no_propagation(self):
|
||||
"""Test that crash states don't propagate when using context managers."""
|
||||
with agent_context() as outer:
|
||||
with agent_context() as inner:
|
||||
# Initially both contexts should have is_crashed as False
|
||||
assert outer.is_crashed() is False
|
||||
assert inner.is_crashed() is False
|
||||
|
||||
# Mark the inner context as crashed
|
||||
inner.mark_agent_crashed("Inner crashed")
|
||||
|
||||
# Inner should be crashed but outer should not
|
||||
assert inner.is_crashed() is True
|
||||
assert outer.is_crashed() is False
|
||||
|
||||
def test_utility_functions_for_crash_state(self):
|
||||
"""Test utility functions for crash state."""
|
||||
with agent_context() as outer:
|
||||
with agent_context() as inner:
|
||||
# Initially both contexts should have is_crashed as False
|
||||
assert is_crashed() is False
|
||||
assert get_crash_message() is None
|
||||
|
||||
# Mark the current context (inner) as crashed
|
||||
mark_agent_crashed("Utility function crash")
|
||||
|
||||
# Current context should be crashed but outer should not
|
||||
assert is_crashed() is True
|
||||
assert get_crash_message() == "Utility function crash"
|
||||
assert inner.is_crashed() is True
|
||||
assert outer.is_crashed() is False
|
||||
Loading…
Reference in New Issue