diff --git a/ra_aid/agent_utils.py b/ra_aid/agent_utils.py index 26e4f50..565d6a9 100644 --- a/ra_aid/agent_utils.py +++ b/ra_aid/agent_utils.py @@ -933,7 +933,6 @@ def run_agent_with_retry( original_prompt, config, test_attempts, auto_test ) ) - cpm(f"res:{should_break, prompt, auto_test, test_attempts}") if should_break: break if prompt != original_prompt: diff --git a/ra_aid/fallback_handler.py b/ra_aid/fallback_handler.py index 3b7bac0..c2cf07f 100644 --- a/ra_aid/fallback_handler.py +++ b/ra_aid/fallback_handler.py @@ -157,10 +157,16 @@ class FallbackHandler: for fallback_model in self.fallback_tool_models: result_list = self.invoke_fallback(fallback_model) if result_list: - # msg_list_response = [SystemMessage(str(msg)) for msg in result_list] return result_list - cpm("All fallback models have failed", title="Fallback Failed") - return None + + cpm("All fallback models have failed.", title="Fallback Failed") + + current_failing_tool_name = self.current_failing_tool_name + self.reset_fallback_handler() + + raise FallbackToolExecutionError( + f"All fallback models have failed for tool: {current_failing_tool_name}" + ) def reset_fallback_handler(self): """ diff --git a/tests/ra_aid/test_fallback_handler.py b/tests/ra_aid/test_fallback_handler.py index c400a19..3a2edcf 100644 --- a/tests/ra_aid/test_fallback_handler.py +++ b/tests/ra_aid/test_fallback_handler.py @@ -22,16 +22,17 @@ class TestFallbackHandler(unittest.TestCase): self.config = { "max_tool_failures": 2, "fallback_tool_models": "dummy-fallback-model", + "experimental_fallback_handler": True, } - self.fallback_handler = FallbackHandler(self.config) + self.fallback_handler = FallbackHandler(self.config, []) self.logger = DummyLogger() self.agent = DummyAgent() def test_handle_failure_increments_counter(self): + from ra_aid.exceptions import ToolExecutionError initial_failures = self.fallback_handler.tool_failure_consecutive_failures - self.fallback_handler.handle_failure( - "dummy_call()", Exception("Test error"), self.logger, self.agent - ) + error_obj = ToolExecutionError("Test error", base_message="dummy_call()", tool_name="dummy_tool") + self.fallback_handler.handle_failure(error_obj, self.agent) self.assertEqual( self.fallback_handler.tool_failure_consecutive_failures, initial_failures + 1, @@ -62,9 +63,7 @@ class TestFallbackHandler(unittest.TestCase): llm.validate_provider_env = dummy_validate_provider_env self.fallback_handler.tool_failure_consecutive_failures = 2 - self.fallback_handler.attempt_fallback( - "dummy_tool_call()", self.logger, self.agent - ) + self.fallback_handler.attempt_fallback() self.assertEqual(self.fallback_handler.tool_failure_consecutive_failures, 0) llm.initialize_llm = original_initialize diff --git a/tests/ra_aid/test_llm.py b/tests/ra_aid/test_llm.py index 6789132..7f96ed4 100644 --- a/tests/ra_aid/test_llm.py +++ b/tests/ra_aid/test_llm.py @@ -121,7 +121,7 @@ def test_initialize_expert_openai_compatible(clean_env, mock_openai, monkeypatch def test_initialize_expert_unsupported_provider(clean_env): """Test error handling for unsupported provider in expert mode.""" - with pytest.raises(ValueError, match=r"Unsupported provider: unknown"): + with pytest.raises(ValueError, match=r"Missing required environment variable for provider: unknown"): initialize_expert_llm("unknown", "model") @@ -197,7 +197,7 @@ def test_initialize_unsupported_provider(clean_env): """Test initialization with unsupported provider raises ValueError""" with pytest.raises(ValueError) as exc_info: initialize_llm("unsupported", "model") - assert str(exc_info.value) == "Unsupported provider: unsupported" + assert str(exc_info.value) == "Missing required environment variable for provider: unsupported" def test_temperature_defaults(clean_env, mock_openai, mock_anthropic, mock_gemini):