diff --git a/ra_aid/agent_utils.py b/ra_aid/agent_utils.py index 565d6a9..5aff588 100644 --- a/ra_aid/agent_utils.py +++ b/ra_aid/agent_utils.py @@ -885,7 +885,7 @@ def _handle_fallback_response( return fallback_response = fallback_handler.handle_failure(error, agent) if fallback_response and agent_type == "React": - msg_list_response = [SystemMessage(str(msg)) for msg in fallback_response] + msg_list_response = [HumanMessage(str(msg)) for msg in fallback_response] msg_list.extend(msg_list_response) diff --git a/ra_aid/fallback_handler.py b/ra_aid/fallback_handler.py index c2cf07f..3d61098 100644 --- a/ra_aid/fallback_handler.py +++ b/ra_aid/fallback_handler.py @@ -147,7 +147,7 @@ class FallbackHandler: Returns: List of [raw_llm_response (SystemMessage), tool_call_result (SystemMessage)] or None. """ - logger.error( + logger.debug( f"Tool call failed {self.tool_failure_consecutive_failures} times. Attempting fallback for tool: {self.current_failing_tool_name}" ) cpm( @@ -323,10 +323,15 @@ class FallbackHandler: Returns: The result of invoking the tool. """ - tool_name_to_tool = {tool.func.__name__: tool for tool in self.tools} + tool_name_to_tool = {getattr(tool.func, "__name__", None): tool for tool in self.tools} name = tool_call_request["name"] arguments = tool_call_request["arguments"] - return tool_name_to_tool[name].invoke(arguments) + if name in tool_name_to_tool: + return tool_name_to_tool[name].invoke(arguments) + elif self.current_tool_to_bind is not None and getattr(self.current_tool_to_bind.func, "__name__", None) == name: + return self.current_tool_to_bind.invoke(arguments) + else: + raise Exception(f"Tool '{name}' not found in available tools.") def base_message_to_tool_call_dict(self, response: BaseMessage): """ diff --git a/tests/ra_aid/test_fallback_handler.py b/tests/ra_aid/test_fallback_handler.py index 31f5201..1a87608 100644 --- a/tests/ra_aid/test_fallback_handler.py +++ b/tests/ra_aid/test_fallback_handler.py @@ -84,6 +84,216 @@ class TestFallbackHandler(unittest.TestCase): llm.merge_chat_history = original_merge llm.validate_provider_env = original_validate + def test_load_fallback_tool_models(self): + import ra_aid.fallback_handler as fh + original_supported = fh.supported_top_tool_models + fh.supported_top_tool_models = [ + {"provider": "dummy", "model": "dummy_model", "type": "prompt"} + ] + models = self.fallback_handler._load_fallback_tool_models(self.config) + self.assertIsInstance(models, list) + fh.supported_top_tool_models = original_supported + + def test_extract_failed_tool_name(self): + from ra_aid.exceptions import ToolExecutionError, FallbackToolExecutionError + # Case when tool_name is provided + error1 = ToolExecutionError("Error", base_message="dummy", tool_name="dummy_tool") + name1 = self.fallback_handler.extract_failed_tool_name(error1) + self.assertEqual(name1, "dummy_tool") + # Case when tool_name is not provided but regex works + error2 = ToolExecutionError("error with name=\"test_tool\"") + name2 = self.fallback_handler.extract_failed_tool_name(error2) + self.assertEqual(name2, "test_tool") + # Case when regex fails and exception is raised + error3 = ToolExecutionError("no tool name here") + with self.assertRaises(FallbackToolExecutionError): + self.fallback_handler.extract_failed_tool_name(error3) + + def test_find_tool_to_bind(self): + # Create a dummy tool to be found + class DummyTool: + def invoke(self, args): + return "result" + class DummyWrapper: + def __init__(self, func): + self.func = func + def dummy_func(args): + return "result" + dummy_tool = DummyTool() + dummy_wrapper = DummyWrapper(dummy_func) + self.agent.tools.append(dummy_wrapper) + tool = self.fallback_handler._find_tool_to_bind(self.agent, dummy_func.__name__) + self.assertIsNotNone(tool) + self.assertEqual(tool.func.__name__, dummy_func.__name__) + + def test_bind_tool_model(self): + # Setup a dummy simple_model with bind_tools method + class DummyModel: + def bind_tools(self, tools, tool_choice=None): + self.bound = True + self.tools = tools + self.tool_choice = tool_choice + return self + def with_retry(self, stop_after_attempt): + return self + def invoke(self, msg_list): + return "dummy_response" + dummy_model = DummyModel() + # Set current tool for binding + class DummyTool: + def invoke(self, args): + return "result" + self.fallback_handler.current_tool_to_bind = DummyTool() + self.fallback_handler.current_failing_tool_name = "test_tool" + # Test with force calling ("fc") type + fallback_model_fc = {"type": "fc"} + bound_model_fc = self.fallback_handler._bind_tool_model(dummy_model, fallback_model_fc) + self.assertTrue(hasattr(bound_model_fc, "tool_choice")) + self.assertEqual(bound_model_fc.tool_choice, "test_tool") + # Test with prompt type + fallback_model_prompt = {"type": "prompt"} + bound_model_prompt = self.fallback_handler._bind_tool_model(dummy_model, fallback_model_prompt) + self.assertTrue(bound_model_prompt.tool_choice is None) + + def test_invoke_fallback(self): + from unittest.mock import patch + import os + import ra_aid.llm as llm + + # Successful fallback scenario with proper API key set + with patch.dict(os.environ, {"DUMMY_API_KEY": "dummy_value"}), \ + patch("ra_aid.fallback_handler.supported_top_tool_models", new=[{"provider": "dummy", "model": "dummy_model", "type": "prompt"}]), \ + patch("ra_aid.fallback_handler.validate_provider_env", return_value=True), \ + patch("ra_aid.fallback_handler.initialize_llm") as mock_init_llm: + class DummyModel: + def bind_tools(self, tools, tool_choice=None): + return self + def with_retry(self, stop_after_attempt): + return self + def invoke(self, msg_list): + return DummyResponse() + class DummyResponse: + additional_kwargs = {"tool_calls": [{"id": "1", "type": "test", "function": {"name": "dummy_tool", "arguments": "{\"a\":1}"}}]} + def dummy_initialize_llm(provider, model_name): + return DummyModel() + mock_init_llm.side_effect = dummy_initialize_llm + # Set current tool for fallback + class DummyTool: + def invoke(self, args): + return "tool_result" + self.fallback_handler.current_tool_to_bind = DummyTool() + self.fallback_handler.current_failing_tool_name = "dummy_tool" + # Add dummy tool for lookup in invoke_prompt_tool_call + self.fallback_handler.tools.append( + type( + "DummyToolWrapper", + (), + { + "func": type("DummyToolFunc", (), {"__name__": "dummy_tool"})(), + "invoke": lambda self, args: "tool_result", + }, + ) + ) + result = self.fallback_handler.invoke_fallback({"provider": "dummy", "model": "dummy_model", "type": "prompt"}) + self.assertIsInstance(result, list) + self.assertEqual(result[1], "tool_result") + + # Failed fallback scenario due to missing API key (simulate by empty environment) + with patch.dict(os.environ, {}, clear=True), \ + patch("ra_aid.fallback_handler.supported_top_tool_models", new=[{"provider": "dummy", "model": "dummy_model", "type": "prompt"}]), \ + patch("ra_aid.fallback_handler.validate_provider_env", return_value=False), \ + patch("ra_aid.fallback_handler.initialize_llm") as mock_init_llm: + class FailingDummyModel: + def bind_tools(self, tools, tool_choice=None): + return self + def with_retry(self, stop_after_attempt): + return self + def invoke(self, msg_list): + raise Exception("API key missing") + def failing_initialize_llm(provider, model_name): + return FailingDummyModel() + mock_init_llm.side_effect = failing_initialize_llm + fallback_result = self.fallback_handler.invoke_fallback({"provider": "dummy", "model": "dummy_model", "type": "prompt"}) + self.assertIsNone(fallback_result) + + # Test that the overall fallback mechanism raises FallbackToolExecutionError when all models fail + # Set failure count to trigger the fallback attempt in attempt_fallback + from ra_aid.exceptions import FallbackToolExecutionError + self.fallback_handler.tool_failure_consecutive_failures = self.fallback_handler.max_failures + with self.assertRaises(FallbackToolExecutionError) as cm: + self.fallback_handler.attempt_fallback() + self.assertIn("All fallback models have failed", str(cm.exception)) + + def test_construct_prompt_msg_list(self): + msgs = self.fallback_handler.construct_prompt_msg_list() + from ra_aid.fallback_handler import SystemMessage, HumanMessage + self.assertTrue(any(isinstance(m, SystemMessage) for m in msgs)) + self.assertTrue(any(isinstance(m, HumanMessage) for m in msgs)) + # Test with failed_messages added + self.fallback_handler.failed_messages.append("failed_msg") + msgs_with_fail = self.fallback_handler.construct_prompt_msg_list() + self.assertTrue(any("failed_msg" in str(m) for m in msgs_with_fail)) + + def test_invoke_prompt_tool_call(self): + # Create dummy tool function + def dummy_tool_func(args): + return "invoked_result" + dummy_tool_func.__name__ = "dummy_tool" + # Create wrapper class + class DummyToolWrapper: + def __init__(self, func): + self.func = func + def invoke(self, args): + return self.func(args) + dummy_wrapper = DummyToolWrapper(dummy_tool_func) + self.fallback_handler.tools = [dummy_wrapper] + tool_call_req = {"name": "dummy_tool", "arguments": {"x": 42}} + result = self.fallback_handler.invoke_prompt_tool_call(tool_call_req) + self.assertEqual(result, "invoked_result") + + def test_base_message_to_tool_call_dict(self): + dummy_tool_call = { + "id": "123", + "type": "test", + "function": {"name": "dummy_tool", "arguments": "{\"x\":42}"} + } + DummyResponse = type("DummyResponse", (), {"additional_kwargs": {"tool_calls": [dummy_tool_call]}}) + result = self.fallback_handler.base_message_to_tool_call_dict(DummyResponse) + self.assertEqual(result["id"], "123") + self.assertEqual(result["name"], "dummy_tool") + self.assertEqual(result["arguments"], {"x": 42}) + + def test_parse_tool_arguments(self): + args_str = '{"a": 1}' + parsed = self.fallback_handler._parse_tool_arguments(args_str) + self.assertEqual(parsed, {"a": 1}) + args_dict = {"b": 2} + parsed_dict = self.fallback_handler._parse_tool_arguments(args_dict) + self.assertEqual(parsed_dict, {"b": 2}) + + def test_get_tool_calls(self): + DummyResponse = type("DummyResponse", (), {})() + DummyResponse.additional_kwargs = {"tool_calls": [{"id": "1"}]} + calls = self.fallback_handler.get_tool_calls(DummyResponse) + self.assertEqual(calls, [{"id": "1"}]) + DummyResponse2 = type("DummyResponse2", (), {"tool_calls": [{"id": "2"}]})() + calls2 = self.fallback_handler.get_tool_calls(DummyResponse2) + self.assertEqual(calls2, [{"id": "2"}]) + dummy_dict = {"additional_kwargs": {"tool_calls": [{"id": "3"}]}} + calls3 = self.fallback_handler.get_tool_calls(dummy_dict) + self.assertEqual(calls3, [{"id": "3"}]) + + def test_handle_failure_response(self): + from ra_aid.exceptions import ToolExecutionError + def dummy_handle_failure(error, agent): + return ["fallback_response"] + self.fallback_handler.handle_failure = dummy_handle_failure + response = self.fallback_handler.handle_failure_response(ToolExecutionError("test", tool_name="dummy_tool"), self.agent, "React") + from ra_aid.fallback_handler import SystemMessage + self.assertTrue(all(isinstance(m, SystemMessage) for m in response)) + response_non = self.fallback_handler.handle_failure_response(ToolExecutionError("test", tool_name="dummy_tool"), self.agent, "Other") + self.assertIsNone(response_non) + if __name__ == "__main__": unittest.main()