From 346e22b5cbb3fbab2ac197a52b1e7f63fa3bdea0 Mon Sep 17 00:00:00 2001 From: AI Christianson Date: Thu, 27 Feb 2025 11:41:36 -0500 Subject: [PATCH] improve agent crash detection --- ra_aid/agent_context.py | 54 ++++++- ra_aid/agent_utils.py | 26 ++++ ra_aid/tools/agent.py | 28 +++- tests/ra_aid/test_agent_context.py | 81 +++++++++++ tests/ra_aid/test_agent_utils.py | 188 +++++++++++++++++++++++++ tests/ra_aid/test_crash_propagation.py | 84 +++++++++++ 6 files changed, 456 insertions(+), 5 deletions(-) create mode 100644 tests/ra_aid/test_crash_propagation.py diff --git a/ra_aid/agent_context.py b/ra_aid/agent_context.py index c8cd01b..e0579b2 100644 --- a/ra_aid/agent_context.py +++ b/ra_aid/agent_context.py @@ -25,6 +25,8 @@ class AgentContext: self.plan_completed = False self.completion_message = "" self.agent_should_exit = False + self.agent_has_crashed = False + self.agent_crashed_message = None # Note: Completion flags (task_completed, plan_completed, completion_message, # agent_should_exit) are no longer inherited from parent contexts @@ -64,6 +66,25 @@ class AgentContext: # Propagate to parent context if it exists if self.parent: self.parent.mark_should_exit() + + def mark_agent_crashed(self, message: str) -> None: + """Mark the agent as crashed with the given message. + + Unlike exit state, crash state does not propagate to parent contexts. + + Args: + message: Error message explaining the crash + """ + self.agent_has_crashed = True + self.agent_crashed_message = message + + def is_crashed(self) -> bool: + """Check if the agent has crashed. + + Returns: + True if the agent has crashed, False otherwise + """ + return self.agent_has_crashed @property def is_completed(self) -> bool: @@ -176,4 +197,35 @@ def mark_should_exit() -> None: """Mark that the agent should exit execution.""" context = get_current_context() if context: - context.mark_should_exit() \ No newline at end of file + context.mark_should_exit() + + +def is_crashed() -> bool: + """Check if the current agent has crashed. + + Returns: + True if the current agent has crashed, False otherwise + """ + context = get_current_context() + return context.is_crashed() if context else False + + +def mark_agent_crashed(message: str) -> None: + """Mark the current agent as crashed with the given message. + + Args: + message: Error message explaining the crash + """ + context = get_current_context() + if context: + context.mark_agent_crashed(message) + + +def get_crash_message() -> Optional[str]: + """Get the crash message from the current context. + + Returns: + The crash message or None if the agent has not crashed + """ + context = get_current_context() + return context.agent_crashed_message if context and context.is_crashed() else None \ No newline at end of file diff --git a/ra_aid/agent_utils.py b/ra_aid/agent_utils.py index 8ce01dd..2447895 100644 --- a/ra_aid/agent_utils.py +++ b/ra_aid/agent_utils.py @@ -933,6 +933,14 @@ def run_agent_with_retry( for attempt in range(max_retries): logger.debug("Attempt %d/%d", attempt + 1, max_retries) check_interrupt() + + # Check if the agent has crashed before attempting to run it + from ra_aid.agent_context import is_crashed, get_crash_message + if is_crashed(): + crash_message = get_crash_message() + logger.error("Agent has crashed: %s", crash_message) + return f"Agent has crashed: {crash_message}" + try: _run_agent_stream(agent, msg_list, config) if fallback_handler: @@ -950,6 +958,15 @@ def run_agent_with_retry( logger.debug("Agent run completed successfully") return "Agent run completed successfully" except ToolExecutionError as e: + # Check if this is a BadRequestError (HTTP 400) which is unretryable + error_str = str(e).lower() + if "400" in error_str or "bad request" in error_str: + from ra_aid.agent_context import mark_agent_crashed + crash_message = f"Unretryable error: {str(e)}" + mark_agent_crashed(crash_message) + logger.error("Agent has crashed: %s", crash_message) + return f"Agent has crashed: {crash_message}" + _handle_fallback_response(e, fallback_handler, agent, msg_list) continue except FallbackToolExecutionError as e: @@ -965,6 +982,15 @@ def run_agent_with_retry( APIError, ValueError, ) as e: + # Check if this is a BadRequestError (HTTP 400) which is unretryable + error_str = str(e).lower() + if ("400" in error_str or "bad request" in error_str) and isinstance(e, APIError): + from ra_aid.agent_context import mark_agent_crashed + crash_message = f"Unretryable API error: {str(e)}" + mark_agent_crashed(crash_message) + logger.error("Agent has crashed: %s", crash_message) + return f"Agent has crashed: {crash_message}" + _handle_api_error(e, attempt, max_retries, base_delay) finally: _decrement_agent_depth() diff --git a/ra_aid/tools/agent.py b/ra_aid/tools/agent.py index 9cbc406..b73df12 100644 --- a/ra_aid/tools/agent.py +++ b/ra_aid/tools/agent.py @@ -5,7 +5,7 @@ from typing import Any, Dict, List, Union from langchain_core.tools import tool from rich.console import Console -from ra_aid.agent_context import get_completion_message, reset_completion_flags +from ra_aid.agent_context import get_completion_message, get_crash_message, is_crashed, reset_completion_flags from ra_aid.console.formatting import print_error from ra_aid.exceptions import AgentInterrupt from ra_aid.tools.memory import _global_memory @@ -301,13 +301,19 @@ def request_task_implementation(task_spec: str) -> str: # Clear completion state reset_completion_flags() + # Check if the agent has crashed + agent_crashed = is_crashed() + crash_message = get_crash_message() if agent_crashed else None + response_data = { "key_facts": get_memory_value("key_facts"), "related_files": get_related_files(), "key_snippets": get_memory_value("key_snippets"), "completion_message": completion_message, - "success": success, + "success": success and not agent_crashed, "reason": reason, + "agent_crashed": agent_crashed, + "crash_message": crash_message, } if work_log is not None: response_data["work_log"] = work_log @@ -320,6 +326,10 @@ def request_task_implementation(task_spec: str) -> str: if response_data.get("completion_message"): markdown_parts.append(f"\n## Completion Message\n\n{response_data['completion_message']}") + # Add crash information if applicable + if response_data.get("agent_crashed"): + markdown_parts.append(f"\n## ⚠️ Agent Crashed ⚠️\n\n**Error:** {response_data.get('crash_message', 'Unknown error')}") + # Add success status status = "Success" if response_data.get("success", False) else "Failed" reason_text = f": {response_data.get('reason')}" if response_data.get("reason") else "" @@ -401,13 +411,19 @@ def request_implementation(task_spec: str) -> str: # Clear completion state reset_completion_flags() + # Check if the agent has crashed + agent_crashed = is_crashed() + crash_message = get_crash_message() if agent_crashed else None + response_data = { "completion_message": completion_message, "key_facts": get_memory_value("key_facts"), "related_files": get_related_files(), "key_snippets": get_memory_value("key_snippets"), - "success": success, + "success": success and not agent_crashed, "reason": reason, + "agent_crashed": agent_crashed, + "crash_message": crash_message, } if work_log is not None: response_data["work_log"] = work_log @@ -420,6 +436,10 @@ def request_implementation(task_spec: str) -> str: if response_data.get("completion_message"): markdown_parts.append(f"\n## Completion Message\n\n{response_data['completion_message']}") + # Add crash information if applicable + if response_data.get("agent_crashed"): + markdown_parts.append(f"\n## ⚠️ Agent Crashed ⚠️\n\n**Error:** {response_data.get('crash_message', 'Unknown error')}") + # Add success status status = "Success" if response_data.get("success", False) else "Failed" reason_text = f": {response_data.get('reason')}" if response_data.get("reason") else "" @@ -446,4 +466,4 @@ def request_implementation(task_spec: str) -> str: # Join all parts into a single markdown string markdown_output = "".join(markdown_parts) - return markdown_output + return markdown_output \ No newline at end of file diff --git a/tests/ra_aid/test_agent_context.py b/tests/ra_aid/test_agent_context.py index 07bef8e..7082146 100644 --- a/tests/ra_aid/test_agent_context.py +++ b/tests/ra_aid/test_agent_context.py @@ -15,6 +15,9 @@ from ra_aid.agent_context import ( get_completion_message, mark_should_exit, should_exit, + mark_agent_crashed, + is_crashed, + get_crash_message, ) @@ -169,6 +172,84 @@ class TestExitPropagation: assert outer.agent_should_exit is True +class TestCrashPropagation: + """Test cases for the agent_has_crashed flag non-propagation.""" + + def test_mark_agent_crashed_no_propagation(self): + """Test that mark_agent_crashed does not propagate to parent contexts.""" + parent = AgentContext() + child = AgentContext(parent_context=parent) + + # Initially both contexts should have agent_has_crashed as False + assert parent.is_crashed() is False + assert child.is_crashed() is False + + # Mark the child context as crashed + child.mark_agent_crashed("Child crashed") + + # Child should be crashed, but parent should not + assert child.is_crashed() is True + assert parent.is_crashed() is False + assert child.agent_crashed_message == "Child crashed" + assert parent.agent_crashed_message is None + + def test_nested_crash_no_propagation(self): + """Test that crash states don't propagate through multiple levels of parent contexts.""" + grandparent = AgentContext() + parent = AgentContext(parent_context=grandparent) + child = AgentContext(parent_context=parent) + + # Initially all contexts should have agent_has_crashed as False + assert grandparent.is_crashed() is False + assert parent.is_crashed() is False + assert child.is_crashed() is False + + # Mark the child context as crashed + child.mark_agent_crashed("Child crashed") + + # Only child should be crashed, parent and grandparent should not + assert child.is_crashed() is True + assert parent.is_crashed() is False + assert grandparent.is_crashed() is False + assert child.agent_crashed_message == "Child crashed" + assert parent.agent_crashed_message is None + assert grandparent.agent_crashed_message is None + + def test_context_manager_crash_no_propagation(self): + """Test that crash state doesn't propagate when using context managers.""" + with agent_context() as outer: + with agent_context() as inner: + # Initially both contexts should have agent_has_crashed as False + assert outer.is_crashed() is False + assert inner.is_crashed() is False + + # Mark the inner context as crashed + inner.mark_agent_crashed("Inner crashed") + + # Inner should be crashed, but outer should not + assert inner.is_crashed() is True + assert outer.is_crashed() is False + assert inner.agent_crashed_message == "Inner crashed" + assert outer.agent_crashed_message is None + + def test_crash_state_not_inherited(self): + """Test that new child contexts don't inherit crash states from parent contexts.""" + parent = AgentContext() + + # Mark the parent as crashed + parent.mark_agent_crashed("Parent crashed") + assert parent.is_crashed() is True + + # Create a child context with the crashed parent as parent_context + child = AgentContext(parent_context=parent) + + # Child should not be crashed even though parent is + assert parent.is_crashed() is True + assert child.is_crashed() is False + assert parent.agent_crashed_message == "Parent crashed" + assert child.agent_crashed_message is None + + class TestThreadIsolation: """Test thread isolation of context variables.""" diff --git a/tests/ra_aid/test_agent_utils.py b/tests/ra_aid/test_agent_utils.py index fca5fcd..094f642 100644 --- a/tests/ra_aid/test_agent_utils.py +++ b/tests/ra_aid/test_agent_utils.py @@ -426,3 +426,191 @@ def test_is_anthropic_claude(): assert not is_anthropic_claude({"provider": "anthropic"}) # Missing model assert not is_anthropic_claude({"model": "claude-2"}) # Missing provider assert not is_anthropic_claude({"provider": "other", "model": "claude-2"}) # Wrong provider + + +def test_run_agent_with_retry_checks_crash_status(monkeypatch): + """Test that run_agent_with_retry checks for crash status at the beginning of each iteration.""" + from ra_aid.agent_utils import run_agent_with_retry + from ra_aid.agent_context import agent_context, mark_agent_crashed + + # Setup mocks for dependencies to isolate our test + dummy_agent = Mock() + + # Track function calls + mock_calls = {"run_agent_stream": 0} + + def mock_run_agent_stream(*args, **kwargs): + mock_calls["run_agent_stream"] += 1 + + def mock_setup_interrupt_handling(): + return None + + def mock_restore_interrupt_handling(handler): + pass + + def mock_increment_agent_depth(): + pass + + def mock_decrement_agent_depth(): + pass + + def mock_is_crashed(): + return ctx.is_crashed() if ctx else False + + def mock_get_crash_message(): + return ctx.agent_crashed_message if ctx and ctx.is_crashed() else None + + # Apply mocks + monkeypatch.setattr("ra_aid.agent_utils._run_agent_stream", mock_run_agent_stream) + monkeypatch.setattr("ra_aid.agent_utils._setup_interrupt_handling", mock_setup_interrupt_handling) + monkeypatch.setattr("ra_aid.agent_utils._restore_interrupt_handling", mock_restore_interrupt_handling) + monkeypatch.setattr("ra_aid.agent_utils._increment_agent_depth", mock_increment_agent_depth) + monkeypatch.setattr("ra_aid.agent_utils._decrement_agent_depth", mock_decrement_agent_depth) + monkeypatch.setattr("ra_aid.agent_utils.check_interrupt", lambda: None) + + # First, run without a crash - agent should be run + with agent_context() as ctx: + monkeypatch.setattr("ra_aid.agent_context.is_crashed", mock_is_crashed) + monkeypatch.setattr("ra_aid.agent_context.get_crash_message", mock_get_crash_message) + result = run_agent_with_retry(dummy_agent, "test prompt", {}) + assert mock_calls["run_agent_stream"] == 1 + + # Reset call counter + mock_calls["run_agent_stream"] = 0 + + # Now run with a crash - agent should not be run + with agent_context() as ctx: + mark_agent_crashed("Test crash message") + monkeypatch.setattr("ra_aid.agent_context.is_crashed", mock_is_crashed) + monkeypatch.setattr("ra_aid.agent_context.get_crash_message", mock_get_crash_message) + result = run_agent_with_retry(dummy_agent, "test prompt", {}) + # Verify _run_agent_stream was not called + assert mock_calls["run_agent_stream"] == 0 + # Verify the result contains the crash message + assert "Agent has crashed: Test crash message" in result + + +def test_run_agent_with_retry_handles_badrequest_error(monkeypatch): + """Test that run_agent_with_retry properly handles BadRequestError as unretryable.""" + from ra_aid.agent_utils import run_agent_with_retry + from ra_aid.exceptions import ToolExecutionError + from ra_aid.agent_context import agent_context, is_crashed + + # Setup mocks + dummy_agent = Mock() + + # Track function calls and simulate BadRequestError + run_count = [0] + + def mock_run_agent_stream(*args, **kwargs): + run_count[0] += 1 + if run_count[0] == 1: + # First call throws a 400 BadRequestError + raise ToolExecutionError("400 Bad Request: Invalid input") + # If it's called again, it should run normally + + def mock_setup_interrupt_handling(): + return None + + def mock_restore_interrupt_handling(handler): + pass + + def mock_increment_agent_depth(): + pass + + def mock_decrement_agent_depth(): + pass + + def mock_mark_agent_crashed(message): + ctx.agent_has_crashed = True + ctx.agent_crashed_message = message + + def mock_is_crashed(): + return ctx.is_crashed() if ctx else False + + # Apply mocks + monkeypatch.setattr("ra_aid.agent_utils._run_agent_stream", mock_run_agent_stream) + monkeypatch.setattr("ra_aid.agent_utils._setup_interrupt_handling", mock_setup_interrupt_handling) + monkeypatch.setattr("ra_aid.agent_utils._restore_interrupt_handling", mock_restore_interrupt_handling) + monkeypatch.setattr("ra_aid.agent_utils._increment_agent_depth", mock_increment_agent_depth) + monkeypatch.setattr("ra_aid.agent_utils._decrement_agent_depth", mock_decrement_agent_depth) + monkeypatch.setattr("ra_aid.agent_utils.check_interrupt", lambda: None) + + with agent_context() as ctx: + monkeypatch.setattr("ra_aid.agent_context.mark_agent_crashed", mock_mark_agent_crashed) + monkeypatch.setattr("ra_aid.agent_context.is_crashed", mock_is_crashed) + + result = run_agent_with_retry(dummy_agent, "test prompt", {}) + # Verify the agent was only run once and not retried + assert run_count[0] == 1 + # Verify the result contains the crash message + assert "Agent has crashed: Unretryable error" in result + # Verify the agent is marked as crashed + assert is_crashed() + + +def test_run_agent_with_retry_handles_api_badrequest_error(monkeypatch): + """Test that run_agent_with_retry properly handles API BadRequestError as unretryable.""" + # Import APIError from anthropic module and patch it on the agent_utils module + from anthropic import APIError as AnthropicAPIError + from ra_aid.agent_utils import run_agent_with_retry + from ra_aid.agent_context import agent_context, is_crashed + + # Setup mocks + dummy_agent = Mock() + + # Track function calls and simulate BadRequestError + run_count = [0] + + # Create a mock APIError class that simulates Anthropic's APIError + class MockAPIError(Exception): + pass + + def mock_run_agent_stream(*args, **kwargs): + run_count[0] += 1 + if run_count[0] == 1: + # First call throws a 400 Bad Request APIError + mock_error = MockAPIError("400 Bad Request") + mock_error.__class__.__name__ = "APIError" # Make it look like Anthropic's APIError + raise mock_error + # If it's called again, it should run normally + + def mock_setup_interrupt_handling(): + return None + + def mock_restore_interrupt_handling(handler): + pass + + def mock_increment_agent_depth(): + pass + + def mock_decrement_agent_depth(): + pass + + def mock_mark_agent_crashed(message): + ctx.agent_has_crashed = True + ctx.agent_crashed_message = message + + def mock_is_crashed(): + return ctx.is_crashed() if ctx else False + + # Apply mocks + monkeypatch.setattr("ra_aid.agent_utils._run_agent_stream", mock_run_agent_stream) + monkeypatch.setattr("ra_aid.agent_utils._setup_interrupt_handling", mock_setup_interrupt_handling) + monkeypatch.setattr("ra_aid.agent_utils._restore_interrupt_handling", mock_restore_interrupt_handling) + monkeypatch.setattr("ra_aid.agent_utils._increment_agent_depth", mock_increment_agent_depth) + monkeypatch.setattr("ra_aid.agent_utils._decrement_agent_depth", mock_decrement_agent_depth) + monkeypatch.setattr("ra_aid.agent_utils.check_interrupt", lambda: None) + monkeypatch.setattr("ra_aid.agent_utils._handle_api_error", lambda *args: None) + monkeypatch.setattr("ra_aid.agent_utils.APIError", MockAPIError) + monkeypatch.setattr("ra_aid.agent_context.mark_agent_crashed", mock_mark_agent_crashed) + monkeypatch.setattr("ra_aid.agent_context.is_crashed", mock_is_crashed) + + with agent_context() as ctx: + result = run_agent_with_retry(dummy_agent, "test prompt", {}) + # Verify the agent was only run once and not retried + assert run_count[0] == 1 + # Verify the result contains the crash message + assert "Agent has crashed: Unretryable API error" in result + # Verify the agent is marked as crashed + assert is_crashed() diff --git a/tests/ra_aid/test_crash_propagation.py b/tests/ra_aid/test_crash_propagation.py new file mode 100644 index 0000000..5b58063 --- /dev/null +++ b/tests/ra_aid/test_crash_propagation.py @@ -0,0 +1,84 @@ +"""Unit tests for crash propagation behavior in agent_context.""" + +import pytest + +from ra_aid.agent_context import ( + AgentContext, + agent_context, + mark_agent_crashed, + is_crashed, + get_crash_message, +) + + +class TestCrashPropagation: + """Test cases for crash state propagation behavior.""" + + def test_mark_agent_crashed_no_propagation(self): + """Test that mark_agent_crashed does not propagate to parent contexts.""" + parent = AgentContext() + child = AgentContext(parent_context=parent) + + # Initially both contexts should have is_crashed as False + assert parent.is_crashed() is False + assert child.is_crashed() is False + + # Mark the child context as crashed + child.mark_agent_crashed("Child crashed") + + # Child should be crashed but parent should not + assert child.is_crashed() is True + assert child.agent_crashed_message == "Child crashed" + assert parent.is_crashed() is False + assert parent.agent_crashed_message is None + + def test_nested_crash_no_propagation(self): + """Test that crash state doesn't propagate through multiple levels of parent contexts.""" + grandparent = AgentContext() + parent = AgentContext(parent_context=grandparent) + child = AgentContext(parent_context=parent) + + # Initially all contexts should have is_crashed as False + assert grandparent.is_crashed() is False + assert parent.is_crashed() is False + assert child.is_crashed() is False + + # Mark the child context as crashed + child.mark_agent_crashed("Child crashed") + + # Only child should be crashed + assert child.is_crashed() is True + assert parent.is_crashed() is False + assert grandparent.is_crashed() is False + + def test_context_manager_crash_no_propagation(self): + """Test that crash states don't propagate when using context managers.""" + with agent_context() as outer: + with agent_context() as inner: + # Initially both contexts should have is_crashed as False + assert outer.is_crashed() is False + assert inner.is_crashed() is False + + # Mark the inner context as crashed + inner.mark_agent_crashed("Inner crashed") + + # Inner should be crashed but outer should not + assert inner.is_crashed() is True + assert outer.is_crashed() is False + + def test_utility_functions_for_crash_state(self): + """Test utility functions for crash state.""" + with agent_context() as outer: + with agent_context() as inner: + # Initially both contexts should have is_crashed as False + assert is_crashed() is False + assert get_crash_message() is None + + # Mark the current context (inner) as crashed + mark_agent_crashed("Utility function crash") + + # Current context should be crashed but outer should not + assert is_crashed() is True + assert get_crash_message() == "Utility function crash" + assert inner.is_crashed() is True + assert outer.is_crashed() is False \ No newline at end of file