# RA.Aid/tests/ra_aid/test_fallback_handler.py
#
# 373 lines
# 14 KiB
# Python

import unittest
from ra_aid.exceptions import FallbackToolExecutionError
from ra_aid.fallback_handler import FallbackHandler
class DummyLogger:
    """No-op logger stand-in exposing only ``debug`` and ``error``.

    Both methods accept a message and discard it, so tests can supply a
    logger object without producing any output.
    """

    def debug(self, msg):
        """Discard a debug-level message."""
        pass

    def error(self, msg):
        """Discard an error-level message."""
        pass
class DummyAgent:
    """Minimal agent stand-in carrying the attributes the tests read."""

    # Provider identifier reported by the fake agent.
    provider = "openai"
    # Class-level default tool list; TestFallbackHandler.setUp rebinds
    # ``tools`` on each instance, so this shared list is not mutated there.
    tools = []
    # No underlying model is attached to the fake agent.
    model = None
class TestFallbackHandler(unittest.TestCase):
    """Unit tests exercising ``FallbackHandler`` with lightweight stand-ins.

    Every test gets a fresh handler, logger and agent from ``setUp``.
    Module-level patches are applied through context managers so they are
    undone even when an assertion inside the test fails.
    """

    def setUp(self):
        """Create the handler under test plus a dummy logger and agent."""
        self.config = {
            "max_tool_failures": 2,
            "fallback_tool_models": "dummy-fallback-model",
            "experimental_fallback_handler": True,
        }
        self.fallback_handler = FallbackHandler(self.config, [])
        self.logger = DummyLogger()
        self.agent = DummyAgent()

        def dummy_tool():
            pass

        class DummyToolWrapper:
            """Wrapper exposing the wrapped callable as ``func``."""

            def __init__(self, func):
                self.func = func

        # Rebind on the instance so the class-level DummyAgent.tools list
        # is never shared between tests.
        self.agent.tools = [DummyToolWrapper(dummy_tool)]

    def test_handle_failure_increments_counter(self):
        """``handle_failure`` bumps the consecutive-failure counter by one."""
        from ra_aid.exceptions import ToolExecutionError

        initial_failures = self.fallback_handler.tool_failure_consecutive_failures
        error_obj = ToolExecutionError(
            "Test error", base_message="dummy_call()", tool_name="dummy_tool"
        )
        self.fallback_handler.handle_failure(error_obj, self.agent, [])
        self.assertEqual(
            self.fallback_handler.tool_failure_consecutive_failures,
            initial_failures + 1,
        )

    def test_attempt_fallback_resets_counter(self):
        """``attempt_fallback`` raises and leaves the failure counter at 0."""
        from unittest.mock import patch

        # Dummy replacements for the LLM factory and environment validation.
        def dummy_initialize_llm(provider, model_name, temperature=None):
            class DummyModel:
                def bind_tools(self, tools, tool_choice):
                    pass

            return DummyModel()

        def dummy_validate_provider_env(provider):
            return True

        import ra_aid.llm as llm

        # patch.object restores the original attributes on exit, even if an
        # assertion below fails, so the patch cannot leak into other tests.
        # NOTE(review): test_invoke_fallback patches these names on
        # ``ra_aid.fallback_handler`` instead; confirm that patching
        # ``ra_aid.llm`` here actually affects the handler.
        with (
            patch.object(llm, "initialize_llm", dummy_initialize_llm),
            patch.object(llm, "validate_provider_env", dummy_validate_provider_env),
        ):
            self.fallback_handler.tool_failure_consecutive_failures = 2
            with self.assertRaises(FallbackToolExecutionError):
                self.fallback_handler.attempt_fallback()
            self.assertEqual(
                self.fallback_handler.tool_failure_consecutive_failures, 0
            )

    def test_load_fallback_tool_models(self):
        """``_load_fallback_tool_models`` returns a list for the test config."""
        from unittest.mock import patch

        import ra_aid.fallback_handler as fh

        supported = [{"provider": "dummy", "model": "dummy_model", "type": "prompt"}]
        # Context manager guarantees the module attribute is restored.
        with patch.object(fh, "supported_top_tool_models", supported):
            models = self.fallback_handler._load_fallback_tool_models(self.config)
            self.assertIsInstance(models, list)

    def test_extract_failed_tool_name(self):
        """Tool name comes from ``tool_name``, else from the message text."""
        from ra_aid.exceptions import ToolExecutionError

        # Case when tool_name is provided
        error1 = ToolExecutionError(
            "Error", base_message="dummy", tool_name="dummy_tool"
        )
        name1 = self.fallback_handler.extract_failed_tool_name(error1)
        self.assertEqual(name1, "dummy_tool")

        # Case when tool_name is not provided but regex works
        error2 = ToolExecutionError('error with name="test_tool"')
        name2 = self.fallback_handler.extract_failed_tool_name(error2)
        self.assertEqual(name2, "test_tool")

        # Case when regex fails and exception is raised
        error3 = ToolExecutionError("no tool name here")
        with self.assertRaises(FallbackToolExecutionError):
            self.fallback_handler.extract_failed_tool_name(error3)

    def test_find_tool_to_bind(self):
        """``_find_tool_to_bind`` locates an agent tool by function name."""

        class DummyWrapper:
            def __init__(self, func):
                self.func = func

        def dummy_func(_args):
            return "result"

        dummy_wrapper = DummyWrapper(dummy_func)
        self.agent.tools.append(dummy_wrapper)
        tool = self.fallback_handler._find_tool_to_bind(
            self.agent, dummy_func.__name__
        )
        self.assertIsNotNone(tool)
        self.assertEqual(tool.func.__name__, dummy_func.__name__)

    def test_bind_tool_model(self):
        """``_bind_tool_model`` sets ``tool_choice`` only for the "fc" type."""

        # Dummy model recording what bind_tools was called with.
        class DummyModel:
            def bind_tools(self, tools, tool_choice=None):
                self.bound = True
                self.tools = tools
                self.tool_choice = tool_choice
                return self

            def with_retry(self, stop_after_attempt):
                return self

            def invoke(self, msg_list):
                return "dummy_response"

        dummy_model = DummyModel()

        # Set current tool for binding
        class DummyTool:
            def invoke(self, args):
                return "result"

        self.fallback_handler.current_tool_to_bind = DummyTool()
        self.fallback_handler.current_failing_tool_name = "test_tool"

        # Test with force calling ("fc") type
        fallback_model_fc = {"type": "fc"}
        bound_model_fc = self.fallback_handler._bind_tool_model(
            dummy_model, fallback_model_fc
        )
        self.assertTrue(hasattr(bound_model_fc, "tool_choice"))
        self.assertEqual(bound_model_fc.tool_choice, "test_tool")

        # Test with prompt type
        fallback_model_prompt = {"type": "prompt"}
        bound_model_prompt = self.fallback_handler._bind_tool_model(
            dummy_model, fallback_model_prompt
        )
        self.assertIsNone(bound_model_prompt.tool_choice)

    def test_invoke_fallback(self):
        """``invoke_fallback`` succeeds, fails, and the overall attempt raises."""
        import os
        from unittest.mock import patch

        # Successful fallback scenario with proper API key set
        with (
            patch.dict(os.environ, {"DUMMY_API_KEY": "dummy_value"}),
            patch(
                "ra_aid.fallback_handler.supported_top_tool_models",
                new=[{"provider": "dummy", "model": "dummy_model", "type": "prompt"}],
            ),
            patch("ra_aid.fallback_handler.validate_provider_env", return_value=True),
            patch("ra_aid.fallback_handler.initialize_llm") as mock_init_llm,
        ):

            class DummyModel:
                def bind_tools(self, tools, tool_choice=None):
                    return self

                def with_retry(self, stop_after_attempt):
                    return self

                def invoke(self, msg_list):
                    return DummyResponse()

            class DummyResponse:
                additional_kwargs = {
                    "tool_calls": [
                        {
                            "id": "1",
                            "type": "test",
                            "function": {"name": "dummy_tool", "arguments": '{"a":1}'},
                        }
                    ]
                }

            def dummy_initialize_llm(provider, model_name):
                return DummyModel()

            mock_init_llm.side_effect = dummy_initialize_llm

            # Set current tool for fallback
            class DummyTool:
                def invoke(self, args):
                    return "tool_result"

            self.fallback_handler.current_tool_to_bind = DummyTool()
            self.fallback_handler.current_failing_tool_name = "dummy_tool"

            # Add dummy tool for lookup in invoke_prompt_tool_call.
            # NOTE: the class object itself (not an instance) is appended, so
            # calling ``invoke`` on it binds the call arguments to ``self``.
            self.fallback_handler.tools.append(
                type(
                    "DummyToolWrapper",
                    (),
                    {
                        "func": type("DummyToolFunc", (), {"__name__": "dummy_tool"})(),
                        "invoke": lambda self, args=None: "tool_result",
                    },
                )
            )
            result = self.fallback_handler.invoke_fallback(
                {"provider": "dummy", "model": "dummy_model", "type": "prompt"}
            )
            self.assertIsInstance(result, list)
            self.assertEqual(result[1], "tool_result")

        # Failed fallback scenario due to missing API key (simulate by empty
        # environment)
        with (
            patch.dict(os.environ, {}, clear=True),
            patch(
                "ra_aid.fallback_handler.supported_top_tool_models",
                new=[{"provider": "dummy", "model": "dummy_model", "type": "prompt"}],
            ),
            patch("ra_aid.fallback_handler.validate_provider_env", return_value=False),
            patch("ra_aid.fallback_handler.initialize_llm") as mock_init_llm,
        ):

            class FailingDummyModel:
                def bind_tools(self, tools, tool_choice=None):
                    return self

                def with_retry(self, stop_after_attempt):
                    return self

                def invoke(self, msg_list):
                    raise Exception("API key missing")

            def failing_initialize_llm(provider, model_name):
                return FailingDummyModel()

            mock_init_llm.side_effect = failing_initialize_llm
            fallback_result = self.fallback_handler.invoke_fallback(
                {"provider": "dummy", "model": "dummy_model", "type": "prompt"}
            )
            self.assertIsNone(fallback_result)

            # Test that the overall fallback mechanism raises
            # FallbackToolExecutionError when all models fail. This runs while
            # the failing patches are still active so that no real provider
            # can be contacted.
            # Set failure count to trigger the fallback attempt in
            # attempt_fallback.
            self.fallback_handler.tool_failure_consecutive_failures = (
                self.fallback_handler.max_failures
            )
            with self.assertRaises(FallbackToolExecutionError) as cm:
                self.fallback_handler.attempt_fallback()
            self.assertIn("All fallback models have failed", str(cm.exception))

    def test_construct_prompt_msg_list(self):
        """Prompt list holds system and human messages plus failed ones."""
        msgs = self.fallback_handler.construct_prompt_msg_list()
        from ra_aid.fallback_handler import HumanMessage, SystemMessage

        self.assertTrue(any(isinstance(m, SystemMessage) for m in msgs))
        self.assertTrue(any(isinstance(m, HumanMessage) for m in msgs))

        # Test with failed_messages added
        self.fallback_handler.failed_messages.append("failed_msg")
        msgs_with_fail = self.fallback_handler.construct_prompt_msg_list()
        self.assertTrue(any("failed_msg" in str(m) for m in msgs_with_fail))

    def test_invoke_prompt_tool_call(self):
        """``invoke_prompt_tool_call`` dispatches to the tool named in the request."""

        # Create dummy tool function
        def dummy_tool_func(args):
            return "invoked_result"

        dummy_tool_func.__name__ = "dummy_tool"

        # Create wrapper class
        class DummyToolWrapper:
            def __init__(self, func):
                self.func = func

            def invoke(self, args):
                return self.func(args)

        dummy_wrapper = DummyToolWrapper(dummy_tool_func)
        self.fallback_handler.tools = [dummy_wrapper]
        tool_call_req = {"name": "dummy_tool", "arguments": {"x": 42}}
        result = self.fallback_handler.invoke_prompt_tool_call(tool_call_req)
        self.assertEqual(result, "invoked_result")

    def test_base_message_to_tool_call_dict(self):
        """A response's first tool call is flattened to id/name/arguments."""
        dummy_tool_call = {
            "id": "123",
            "type": "test",
            "function": {"name": "dummy_tool", "arguments": '{"x":42}'},
        }
        # The class object (not an instance) is passed; ``additional_kwargs``
        # is a class attribute, so attribute lookup works on it directly.
        DummyResponse = type(
            "DummyResponse",
            (),
            {"additional_kwargs": {"tool_calls": [dummy_tool_call]}},
        )
        result = self.fallback_handler.base_message_to_tool_call_dict(DummyResponse)
        self.assertEqual(result["id"], "123")
        self.assertEqual(result["name"], "dummy_tool")
        self.assertEqual(result["arguments"], {"x": 42})

    def test_parse_tool_arguments(self):
        """JSON strings are decoded; dict arguments pass through unchanged."""
        args_str = '{"a": 1}'
        parsed = self.fallback_handler._parse_tool_arguments(args_str)
        self.assertEqual(parsed, {"a": 1})

        args_dict = {"b": 2}
        parsed_dict = self.fallback_handler._parse_tool_arguments(args_dict)
        self.assertEqual(parsed_dict, {"b": 2})

    def test_get_tool_calls(self):
        """Tool calls are read from three different response shapes."""
        # Object exposing tool calls via ``additional_kwargs``.
        response_kwargs = type("DummyResponse", (), {})()
        response_kwargs.additional_kwargs = {"tool_calls": [{"id": "1"}]}
        calls = self.fallback_handler.get_tool_calls(response_kwargs)
        self.assertEqual(calls, [{"id": "1"}])

        # Object exposing a ``tool_calls`` attribute directly.
        response_attr = type("DummyResponse2", (), {"tool_calls": [{"id": "2"}]})()
        calls2 = self.fallback_handler.get_tool_calls(response_attr)
        self.assertEqual(calls2, [{"id": "2"}])

        # Plain dict carrying ``additional_kwargs``.
        dummy_dict = {"additional_kwargs": {"tool_calls": [{"id": "3"}]}}
        calls3 = self.fallback_handler.get_tool_calls(dummy_dict)
        self.assertEqual(calls3, [{"id": "3"}])

    def test_handle_failure_response(self):
        """"React" agents get SystemMessages back; other agent types get None."""
        from ra_aid.exceptions import ToolExecutionError

        # ``handle_failure`` is called with a third positional argument in
        # test_handle_failure_increments_counter, so the stub tolerates any
        # extra arguments the handler may forward.
        def dummy_handle_failure(error, agent, *_args, **_kwargs):
            return ["fallback_response"]

        self.fallback_handler.handle_failure = dummy_handle_failure
        response = self.fallback_handler.handle_failure_response(
            ToolExecutionError("test", tool_name="dummy_tool"), self.agent, "React"
        )
        from ra_aid.fallback_handler import SystemMessage

        self.assertTrue(all(isinstance(m, SystemMessage) for m in response))

        response_non = self.fallback_handler.handle_failure_response(
            ToolExecutionError("test", tool_name="dummy_tool"), self.agent, "Other"
        )
        self.assertIsNone(response_non)

    def test_init_msg_list_non_overlapping(self):
        """First two and last two messages are kept when they do not overlap."""
        full_list = ["msg1", "msg2", "msg3", "msg4", "msg5"]
        self.fallback_handler.init_msg_list(full_list)
        # Expected merged list: first two ("msg1", "msg2") plus last two
        # ("msg4", "msg5").
        self.assertEqual(
            self.fallback_handler.msg_list, ["msg1", "msg2", "msg4", "msg5"]
        )

    def test_init_msg_list_with_overlap(self):
        """A tail message already present in the head is not added twice."""
        full_list = ["msg1", "msg2", "msg1", "msg3"]
        self.fallback_handler.init_msg_list(full_list)
        # Expected merged list: first two ("msg1", "msg2") plus "msg3" from
        # the last two, since "msg1" was already present.
        self.assertEqual(self.fallback_handler.msg_list, ["msg1", "msg2", "msg3"])
# Run the test suite when this module is executed directly as a script.
if __name__ == "__main__":
    unittest.main()