improve agent crash detection

This commit is contained in:
AI Christianson 2025-02-27 11:41:36 -05:00
parent 9202cf0d6d
commit 346e22b5cb
6 changed files with 456 additions and 5 deletions

View File

@ -25,6 +25,8 @@ class AgentContext:
self.plan_completed = False
self.completion_message = ""
self.agent_should_exit = False
self.agent_has_crashed = False
self.agent_crashed_message = None
# Note: Completion flags (task_completed, plan_completed, completion_message,
# agent_should_exit) are no longer inherited from parent contexts
@ -64,6 +66,25 @@ class AgentContext:
# Propagate to parent context if it exists
if self.parent:
self.parent.mark_should_exit()
def mark_agent_crashed(self, message: str) -> None:
    """Record on this context that the agent crashed.

    Only this context's own attributes are written; no parent context is
    touched here, in contrast with the exit flag.

    Args:
        message: Error message explaining the crash
    """
    # Store the explanation together with the flag so readers of the flag
    # can always retrieve the matching message.
    self.agent_crashed_message = message
    self.agent_has_crashed = True
def is_crashed(self) -> bool:
    """Report whether this context has been marked as crashed.

    Returns:
        The current value of the ``agent_has_crashed`` flag
    """
    crashed = self.agent_has_crashed
    return crashed
@property
def is_completed(self) -> bool:
@ -176,4 +197,35 @@ def mark_should_exit() -> None:
"""Mark that the agent should exit execution."""
context = get_current_context()
if context:
context.mark_should_exit()
context.mark_should_exit()
def is_crashed() -> bool:
    """Tell whether the agent of the current context has crashed.

    Returns:
        The crash state of the current context, or False when there is no
        current context
    """
    current = get_current_context()
    if not current:
        # Without a context there is nothing that could have crashed.
        return False
    return current.is_crashed()
def mark_agent_crashed(message: str) -> None:
    """Flag the agent of the current context as crashed.

    Does nothing when there is no current context.

    Args:
        message: Error message explaining the crash
    """
    current = get_current_context()
    if not current:
        return
    current.mark_agent_crashed(message)
def get_crash_message() -> Optional[str]:
    """Fetch the crash message stored on the current context.

    Returns:
        The stored crash message when a current context exists and reports
        itself as crashed; otherwise None
    """
    current = get_current_context()
    if current and current.is_crashed():
        return current.agent_crashed_message
    return None

View File

@ -933,6 +933,14 @@ def run_agent_with_retry(
for attempt in range(max_retries):
logger.debug("Attempt %d/%d", attempt + 1, max_retries)
check_interrupt()
# Check if the agent has crashed before attempting to run it
from ra_aid.agent_context import is_crashed, get_crash_message
if is_crashed():
crash_message = get_crash_message()
logger.error("Agent has crashed: %s", crash_message)
return f"Agent has crashed: {crash_message}"
try:
_run_agent_stream(agent, msg_list, config)
if fallback_handler:
@ -950,6 +958,15 @@ def run_agent_with_retry(
logger.debug("Agent run completed successfully")
return "Agent run completed successfully"
except ToolExecutionError as e:
# Check if this is a BadRequestError (HTTP 400) which is unretryable
error_str = str(e).lower()
if "400" in error_str or "bad request" in error_str:
from ra_aid.agent_context import mark_agent_crashed
crash_message = f"Unretryable error: {str(e)}"
mark_agent_crashed(crash_message)
logger.error("Agent has crashed: %s", crash_message)
return f"Agent has crashed: {crash_message}"
_handle_fallback_response(e, fallback_handler, agent, msg_list)
continue
except FallbackToolExecutionError as e:
@ -965,6 +982,15 @@ def run_agent_with_retry(
APIError,
ValueError,
) as e:
# Check if this is a BadRequestError (HTTP 400) which is unretryable
error_str = str(e).lower()
if ("400" in error_str or "bad request" in error_str) and isinstance(e, APIError):
from ra_aid.agent_context import mark_agent_crashed
crash_message = f"Unretryable API error: {str(e)}"
mark_agent_crashed(crash_message)
logger.error("Agent has crashed: %s", crash_message)
return f"Agent has crashed: {crash_message}"
_handle_api_error(e, attempt, max_retries, base_delay)
finally:
_decrement_agent_depth()

View File

@ -5,7 +5,7 @@ from typing import Any, Dict, List, Union
from langchain_core.tools import tool
from rich.console import Console
from ra_aid.agent_context import get_completion_message, reset_completion_flags
from ra_aid.agent_context import get_completion_message, get_crash_message, is_crashed, reset_completion_flags
from ra_aid.console.formatting import print_error
from ra_aid.exceptions import AgentInterrupt
from ra_aid.tools.memory import _global_memory
@ -301,13 +301,19 @@ def request_task_implementation(task_spec: str) -> str:
# Clear completion state
reset_completion_flags()
# Check if the agent has crashed
agent_crashed = is_crashed()
crash_message = get_crash_message() if agent_crashed else None
response_data = {
"key_facts": get_memory_value("key_facts"),
"related_files": get_related_files(),
"key_snippets": get_memory_value("key_snippets"),
"completion_message": completion_message,
"success": success,
"success": success and not agent_crashed,
"reason": reason,
"agent_crashed": agent_crashed,
"crash_message": crash_message,
}
if work_log is not None:
response_data["work_log"] = work_log
@ -320,6 +326,10 @@ def request_task_implementation(task_spec: str) -> str:
if response_data.get("completion_message"):
markdown_parts.append(f"\n## Completion Message\n\n{response_data['completion_message']}")
# Add crash information if applicable
if response_data.get("agent_crashed"):
markdown_parts.append(f"\n## ⚠️ Agent Crashed ⚠️\n\n**Error:** {response_data.get('crash_message', 'Unknown error')}")
# Add success status
status = "Success" if response_data.get("success", False) else "Failed"
reason_text = f": {response_data.get('reason')}" if response_data.get("reason") else ""
@ -401,13 +411,19 @@ def request_implementation(task_spec: str) -> str:
# Clear completion state
reset_completion_flags()
# Check if the agent has crashed
agent_crashed = is_crashed()
crash_message = get_crash_message() if agent_crashed else None
response_data = {
"completion_message": completion_message,
"key_facts": get_memory_value("key_facts"),
"related_files": get_related_files(),
"key_snippets": get_memory_value("key_snippets"),
"success": success,
"success": success and not agent_crashed,
"reason": reason,
"agent_crashed": agent_crashed,
"crash_message": crash_message,
}
if work_log is not None:
response_data["work_log"] = work_log
@ -420,6 +436,10 @@ def request_implementation(task_spec: str) -> str:
if response_data.get("completion_message"):
markdown_parts.append(f"\n## Completion Message\n\n{response_data['completion_message']}")
# Add crash information if applicable
if response_data.get("agent_crashed"):
markdown_parts.append(f"\n## ⚠️ Agent Crashed ⚠️\n\n**Error:** {response_data.get('crash_message', 'Unknown error')}")
# Add success status
status = "Success" if response_data.get("success", False) else "Failed"
reason_text = f": {response_data.get('reason')}" if response_data.get("reason") else ""
@ -446,4 +466,4 @@ def request_implementation(task_spec: str) -> str:
# Join all parts into a single markdown string
markdown_output = "".join(markdown_parts)
return markdown_output
return markdown_output

View File

@ -15,6 +15,9 @@ from ra_aid.agent_context import (
get_completion_message,
mark_should_exit,
should_exit,
mark_agent_crashed,
is_crashed,
get_crash_message,
)
@ -169,6 +172,84 @@ class TestExitPropagation:
assert outer.agent_should_exit is True
class TestCrashPropagation:
    """Checks that the agent_has_crashed flag stays on the context it was set on."""

    def test_mark_agent_crashed_no_propagation(self):
        """A crash marked on a child context must not show up on its parent."""
        parent_ctx = AgentContext()
        child_ctx = AgentContext(parent_context=parent_ctx)

        # Neither context starts out crashed.
        for fresh in (parent_ctx, child_ctx):
            assert fresh.is_crashed() is False

        child_ctx.mark_agent_crashed("Child crashed")

        # Flag: only the child is crashed.
        assert child_ctx.is_crashed() is True
        assert parent_ctx.is_crashed() is False
        # Message: only the child carries one.
        assert child_ctx.agent_crashed_message == "Child crashed"
        assert parent_ctx.agent_crashed_message is None

    def test_nested_crash_no_propagation(self):
        """A crash on the innermost of three chained contexts stays there."""
        grandparent_ctx = AgentContext()
        parent_ctx = AgentContext(parent_context=grandparent_ctx)
        child_ctx = AgentContext(parent_context=parent_ctx)

        # All three levels start out healthy.
        for fresh in (grandparent_ctx, parent_ctx, child_ctx):
            assert fresh.is_crashed() is False

        child_ctx.mark_agent_crashed("Child crashed")

        # The child is the only crashed context in the chain.
        assert child_ctx.is_crashed() is True
        assert parent_ctx.is_crashed() is False
        assert grandparent_ctx.is_crashed() is False
        assert child_ctx.agent_crashed_message == "Child crashed"
        assert parent_ctx.agent_crashed_message is None
        assert grandparent_ctx.agent_crashed_message is None

    def test_context_manager_crash_no_propagation(self):
        """Crash state stays on the inner context when contexts come from agent_context()."""
        with agent_context() as outer_ctx:
            with agent_context() as inner_ctx:
                # Both contexts start out healthy.
                assert outer_ctx.is_crashed() is False
                assert inner_ctx.is_crashed() is False

                inner_ctx.mark_agent_crashed("Inner crashed")

                # Only the inner context records the crash.
                assert inner_ctx.is_crashed() is True
                assert outer_ctx.is_crashed() is False
                assert inner_ctx.agent_crashed_message == "Inner crashed"
                assert outer_ctx.agent_crashed_message is None

    def test_crash_state_not_inherited(self):
        """A context created under an already-crashed parent starts out healthy."""
        parent_ctx = AgentContext()
        parent_ctx.mark_agent_crashed("Parent crashed")
        assert parent_ctx.is_crashed() is True

        # The child is created only after the parent has crashed.
        child_ctx = AgentContext(parent_context=parent_ctx)

        assert parent_ctx.is_crashed() is True
        assert child_ctx.is_crashed() is False
        assert parent_ctx.agent_crashed_message == "Parent crashed"
        assert child_ctx.agent_crashed_message is None
class TestThreadIsolation:
"""Test thread isolation of context variables."""

View File

@ -426,3 +426,191 @@ def test_is_anthropic_claude():
assert not is_anthropic_claude({"provider": "anthropic"}) # Missing model
assert not is_anthropic_claude({"model": "claude-2"}) # Missing provider
assert not is_anthropic_claude({"provider": "other", "model": "claude-2"}) # Wrong provider
def test_run_agent_with_retry_checks_crash_status(monkeypatch):
    """Verify run_agent_with_retry only runs the agent when the context is not crashed."""
    from ra_aid.agent_utils import run_agent_with_retry
    from ra_aid.agent_context import agent_context, mark_agent_crashed

    agent_stub = Mock()

    # One entry is appended per call to the patched _run_agent_stream.
    stream_calls = []

    def fake_run_agent_stream(*args, **kwargs):
        stream_calls.append(1)

    # The crash helpers below read the context bound by the ``with`` blocks
    # further down; they are only ever called while such a block is active.
    def fake_is_crashed():
        return active_ctx.is_crashed() if active_ctx else False

    def fake_get_crash_message():
        if active_ctx and active_ctx.is_crashed():
            return active_ctx.agent_crashed_message
        return None

    # Replace the collaborators of run_agent_with_retry with inert stand-ins.
    replacements = {
        "ra_aid.agent_utils._run_agent_stream": fake_run_agent_stream,
        "ra_aid.agent_utils._setup_interrupt_handling": lambda: None,
        "ra_aid.agent_utils._restore_interrupt_handling": lambda handler: None,
        "ra_aid.agent_utils._increment_agent_depth": lambda: None,
        "ra_aid.agent_utils._decrement_agent_depth": lambda: None,
        "ra_aid.agent_utils.check_interrupt": lambda: None,
    }
    for target, replacement in replacements.items():
        monkeypatch.setattr(target, replacement)

    # Healthy context: the agent stream must be invoked exactly once.
    with agent_context() as active_ctx:
        monkeypatch.setattr("ra_aid.agent_context.is_crashed", fake_is_crashed)
        monkeypatch.setattr("ra_aid.agent_context.get_crash_message", fake_get_crash_message)
        outcome = run_agent_with_retry(agent_stub, "test prompt", {})
        assert len(stream_calls) == 1

    stream_calls.clear()

    # Crashed context: the agent stream must not be invoked at all.
    with agent_context() as active_ctx:
        mark_agent_crashed("Test crash message")
        monkeypatch.setattr("ra_aid.agent_context.is_crashed", fake_is_crashed)
        monkeypatch.setattr("ra_aid.agent_context.get_crash_message", fake_get_crash_message)
        outcome = run_agent_with_retry(agent_stub, "test prompt", {})
        assert len(stream_calls) == 0
        # The returned text carries the stored crash message.
        assert "Agent has crashed: Test crash message" in outcome
def test_run_agent_with_retry_handles_badrequest_error(monkeypatch):
    """Verify a ToolExecutionError mentioning HTTP 400 stops retries and marks a crash."""
    from ra_aid.agent_utils import run_agent_with_retry
    from ra_aid.exceptions import ToolExecutionError
    from ra_aid.agent_context import agent_context, is_crashed

    agent_stub = Mock()

    # One entry is appended per call to the patched _run_agent_stream.
    attempts = []

    def failing_run_agent_stream(*args, **kwargs):
        attempts.append(1)
        if len(attempts) == 1:
            # Only the first attempt fails; any retry would succeed silently.
            raise ToolExecutionError("400 Bad Request: Invalid input")

    # The crash helpers write to / read from the context bound by the
    # ``with`` block below.
    def record_crash(message):
        active_ctx.agent_has_crashed = True
        active_ctx.agent_crashed_message = message

    def crashed_on_active_ctx():
        return active_ctx.is_crashed() if active_ctx else False

    # Replace the collaborators of run_agent_with_retry with inert stand-ins.
    replacements = {
        "ra_aid.agent_utils._run_agent_stream": failing_run_agent_stream,
        "ra_aid.agent_utils._setup_interrupt_handling": lambda: None,
        "ra_aid.agent_utils._restore_interrupt_handling": lambda handler: None,
        "ra_aid.agent_utils._increment_agent_depth": lambda: None,
        "ra_aid.agent_utils._decrement_agent_depth": lambda: None,
        "ra_aid.agent_utils.check_interrupt": lambda: None,
    }
    for target, replacement in replacements.items():
        monkeypatch.setattr(target, replacement)

    with agent_context() as active_ctx:
        monkeypatch.setattr("ra_aid.agent_context.mark_agent_crashed", record_crash)
        monkeypatch.setattr("ra_aid.agent_context.is_crashed", crashed_on_active_ctx)
        outcome = run_agent_with_retry(agent_stub, "test prompt", {})

        # A single attempt means the error was not retried.
        assert len(attempts) == 1
        assert "Agent has crashed: Unretryable error" in outcome
        # Checked inside the block so the crashed context is still current.
        assert is_crashed()
def test_run_agent_with_retry_handles_api_badrequest_error(monkeypatch):
    """Test that run_agent_with_retry treats a 400 Bad Request APIError as unretryable.

    A stand-in exception class is patched over ``ra_aid.agent_utils.APIError``
    so the test does not need to construct a real SDK error; the unused
    ``anthropic`` import the test used to carry has been dropped.
    """
    from ra_aid.agent_utils import run_agent_with_retry
    from ra_aid.agent_context import agent_context, is_crashed

    # Setup mocks
    dummy_agent = Mock()

    # Track how many times the agent stream was started
    run_count = [0]

    # Stand-in for the APIError class referenced by ra_aid.agent_utils
    class MockAPIError(Exception):
        pass

    def mock_run_agent_stream(*args, **kwargs):
        run_count[0] += 1
        if run_count[0] == 1:
            # First call throws a 400 Bad Request APIError
            mock_error = MockAPIError("400 Bad Request")
            # Rename the stand-in class, presumably so any name-based
            # handling sees "APIError" (kept from the original test).
            mock_error.__class__.__name__ = "APIError"
            raise mock_error
        # If it's called again, it should run normally

    def mock_setup_interrupt_handling():
        return None

    def mock_restore_interrupt_handling(handler):
        pass

    def mock_increment_agent_depth():
        pass

    def mock_decrement_agent_depth():
        pass

    # The crash helpers operate on the context bound by the ``with`` block
    # below; they are only called while that block is active.
    def mock_mark_agent_crashed(message):
        ctx.agent_has_crashed = True
        ctx.agent_crashed_message = message

    def mock_is_crashed():
        return ctx.is_crashed() if ctx else False

    # Apply mocks
    monkeypatch.setattr("ra_aid.agent_utils._run_agent_stream", mock_run_agent_stream)
    monkeypatch.setattr("ra_aid.agent_utils._setup_interrupt_handling", mock_setup_interrupt_handling)
    monkeypatch.setattr("ra_aid.agent_utils._restore_interrupt_handling", mock_restore_interrupt_handling)
    monkeypatch.setattr("ra_aid.agent_utils._increment_agent_depth", mock_increment_agent_depth)
    monkeypatch.setattr("ra_aid.agent_utils._decrement_agent_depth", mock_decrement_agent_depth)
    monkeypatch.setattr("ra_aid.agent_utils.check_interrupt", lambda: None)
    monkeypatch.setattr("ra_aid.agent_utils._handle_api_error", lambda *args: None)
    monkeypatch.setattr("ra_aid.agent_utils.APIError", MockAPIError)
    monkeypatch.setattr("ra_aid.agent_context.mark_agent_crashed", mock_mark_agent_crashed)
    monkeypatch.setattr("ra_aid.agent_context.is_crashed", mock_is_crashed)

    with agent_context() as ctx:
        result = run_agent_with_retry(dummy_agent, "test prompt", {})

        # Verify the agent was only run once and not retried
        assert run_count[0] == 1
        # Verify the result contains the crash message
        assert "Agent has crashed: Unretryable API error" in result
        # Verify the agent is marked as crashed (checked while the context
        # is still current)
        assert is_crashed()

View File

@ -0,0 +1,84 @@
"""Unit tests for crash propagation behavior in agent_context."""
import pytest
from ra_aid.agent_context import (
AgentContext,
agent_context,
mark_agent_crashed,
is_crashed,
get_crash_message,
)
class TestCrashPropagation:
    """Checks that crash state stays on the context where it was recorded."""

    def test_mark_agent_crashed_no_propagation(self):
        """A crash marked on a child context must not show up on its parent."""
        parent_ctx = AgentContext()
        child_ctx = AgentContext(parent_context=parent_ctx)

        # Neither context starts out crashed.
        for fresh in (parent_ctx, child_ctx):
            assert fresh.is_crashed() is False

        child_ctx.mark_agent_crashed("Child crashed")

        # The child holds the flag and the message; the parent holds neither.
        assert child_ctx.is_crashed() is True
        assert child_ctx.agent_crashed_message == "Child crashed"
        assert parent_ctx.is_crashed() is False
        assert parent_ctx.agent_crashed_message is None

    def test_nested_crash_no_propagation(self):
        """A crash on the innermost of three chained contexts stays there."""
        grandparent_ctx = AgentContext()
        parent_ctx = AgentContext(parent_context=grandparent_ctx)
        child_ctx = AgentContext(parent_context=parent_ctx)

        # All three levels start out healthy.
        for fresh in (grandparent_ctx, parent_ctx, child_ctx):
            assert fresh.is_crashed() is False

        child_ctx.mark_agent_crashed("Child crashed")

        # The child is the only crashed context in the chain.
        assert child_ctx.is_crashed() is True
        assert parent_ctx.is_crashed() is False
        assert grandparent_ctx.is_crashed() is False

    def test_context_manager_crash_no_propagation(self):
        """Crash state stays on the inner context when contexts come from agent_context()."""
        with agent_context() as outer_ctx:
            with agent_context() as inner_ctx:
                # Both contexts start out healthy.
                assert outer_ctx.is_crashed() is False
                assert inner_ctx.is_crashed() is False

                inner_ctx.mark_agent_crashed("Inner crashed")

                # Only the inner context records the crash.
                assert inner_ctx.is_crashed() is True
                assert outer_ctx.is_crashed() is False

    def test_utility_functions_for_crash_state(self):
        """The module-level crash helpers act on the innermost active context."""
        with agent_context() as outer_ctx:
            with agent_context() as inner_ctx:
                # Nothing has crashed yet, so the helpers report a clean state.
                assert is_crashed() is False
                assert get_crash_message() is None

                # No context is passed: the helper targets the current one.
                mark_agent_crashed("Utility function crash")

                assert is_crashed() is True
                assert get_crash_message() == "Utility function crash"
                # The crash landed on the inner context and not on the outer.
                assert inner_ctx.is_crashed() is True
                assert outer_ctx.is_crashed() is False