diff --git a/ra_aid/agent_utils.py b/ra_aid/agent_utils.py index cf85a0c..3aea6a8 100644 --- a/ra_aid/agent_utils.py +++ b/ra_aid/agent_utils.py @@ -9,7 +9,6 @@ import uuid from datetime import datetime from typing import Any, Dict, List, Literal, Optional, Sequence -from langgraph.graph.graph import CompiledGraph import litellm from anthropic import APIError, APITimeoutError, InternalServerError, RateLimitError from langchain_core.language_models import BaseChatModel @@ -20,6 +19,7 @@ from langchain_core.messages import ( ) from langchain_core.tools import tool from langgraph.checkpoint.memory import MemorySaver +from langgraph.graph.graph import CompiledGraph from langgraph.prebuilt import create_react_agent from langgraph.prebuilt.chat_agent_executor import AgentState from litellm import get_model_info @@ -876,9 +876,7 @@ def run_agent_with_retry( logger.debug("Agent run completed successfully") return "Agent run completed successfully" except ToolExecutionError as e: - fallback_response = _handle_tool_execution_error( - fallback_handler, agent, e - ) + fallback_response = fallback_handler.handle_failure(e, agent) if fallback_response: prompt = original_prompt + "\n" + fallback_response continue @@ -895,42 +893,3 @@ def run_agent_with_retry( finally: _decrement_agent_depth() _restore_interrupt_handling(original_handler) - - -def _handle_tool_execution_error( - fallback_handler: FallbackHandler, - agent: CiaynAgent | CompiledGraph, - error: ToolExecutionError, -): - logger.debug("Entering _handle_tool_execution_error with error: %s", error) - if error.tool_name: - failed_tool_call_name = error.tool_name - logger.debug( - "Extracted failed_tool_call_name from error.tool_name: %s", - failed_tool_call_name, - ) - else: - import re - - msg = str(error) - logger.debug("Error message: %s", msg) - match = re.search(r"name=['\"](\w+)['\"]", msg) - if match: - failed_tool_call_name = match.group(1) - logger.debug( - "Extracted failed_tool_call_name using regex: %s", failed_tool_call_name - ) - else: - failed_tool_call_name = "Tool execution error" - logger.debug( - "Defaulting failed_tool_call_name to: %s", failed_tool_call_name - ) - logger.debug( - "Calling fallback_handler.handle_failure with failed_tool_call_name: %s", - failed_tool_call_name, - ) - fallback_response = fallback_handler.handle_failure( - failed_tool_call_name, error, agent - ) - logger.debug("Fallback response received: %s", fallback_response) - return fallback_response diff --git a/ra_aid/console/output.py b/ra_aid/console/output.py index aad96e6..7e62e1b 100644 --- a/ra_aid/console/output.py +++ b/ra_aid/console/output.py @@ -44,7 +44,9 @@ def print_agent_output(chunk: Dict[str, Any]) -> None: ) ) tool_name = getattr(msg, "name", None) - raise ToolExecutionError(err_msg, tool_name=tool_name) + cpm(f"type(msg): {type(msg)}") + cpm(f"msg: {msg}") + raise ToolExecutionError(err_msg, tool_name=tool_name, base_message=msg) def cpm(message: str, title: Optional[str] = None, border_style: str = "blue") -> None: diff --git a/ra_aid/exceptions.py b/ra_aid/exceptions.py index d8bc532..34710d9 100644 --- a/ra_aid/exceptions.py +++ b/ra_aid/exceptions.py @@ -1,5 +1,9 @@ """Custom exceptions for RA.Aid.""" +from typing import Optional + +from langchain_core.messages import BaseMessage + class AgentInterrupt(Exception): """Exception raised when an agent's execution is interrupted. @@ -17,6 +21,13 @@ class ToolExecutionError(Exception): This exception is used to distinguish tool execution failures from other types of errors in the agent system. """ - def __init__(self, message, tool_name=None): + + def __init__( + self, + message: str, + base_message: Optional[BaseMessage] = None, + tool_name: Optional[str] = None, + ): super().__init__(message) + self.base_message = base_message self.tool_name = tool_name diff --git a/ra_aid/fallback_handler.py b/ra_aid/fallback_handler.py index 31df7d5..5e80826 100644 --- a/ra_aid/fallback_handler.py +++ b/ra_aid/fallback_handler.py @@ -1,20 +1,22 @@ +import json +import re + from langchain_core.tools import BaseTool from langgraph.graph.graph import CompiledGraph from langgraph.graph.message import BaseMessage -from ra_aid.console.output import cpm -import json - from ra_aid.agents.ciayn_agent import CiaynAgent from ra_aid.config import ( DEFAULT_MAX_TOOL_FAILURES, FALLBACK_TOOL_MODEL_LIMIT, RETRY_FALLBACK_COUNT, ) -from ra_aid.logging_config import get_logger -from ra_aid.tool_leaderboard import supported_top_tool_models -from rich.console import Console +from ra_aid.console.output import cpm +from ra_aid.exceptions import ToolExecutionError from ra_aid.llm import initialize_llm, validate_provider_env +from ra_aid.logging_config import get_logger +from ra_aid.tool_configs import get_all_tools +from ra_aid.tool_leaderboard import supported_top_tool_models logger = get_logger(__name__) @@ -41,9 +43,12 @@ class FallbackHandler: self.tools: list[BaseTool] = tools self.fallback_enabled = config.get("fallback_tool_enabled", True) self.fallback_tool_models = self._load_fallback_tool_models(config) + self.max_failures = config.get("max_tool_failures", DEFAULT_MAX_TOOL_FAILURES) self.tool_failure_consecutive_failures = 0 + self.failed_messages = set() self.tool_failure_used_fallbacks = set() - self.console = Console() + self.current_failing_tool_name = "" + self.current_tool_to_bind = None def _load_fallback_tool_models(self, config): """ @@ -87,66 +92,104 @@ class FallbackHandler: cpm(message, title="Fallback Models") return final_models - def handle_failure( - self, code: str, error: Exception, agent: CiaynAgent | CompiledGraph - ): + def handle_failure(self, error: Exception, agent: CiaynAgent | CompiledGraph): """ Handle a tool failure by incrementing the failure counter and triggering fallback if thresholds are exceeded. Args: - code (str): The code that failed to execute. - error (Exception): The exception raised during execution. - logger: Logger instance for logging. + error (Exception): The exception raised during execution. If the exception has a 'base_message' attribute, that message is recorded. agent: The agent instance on which fallback may be executed. """ - logger.debug( - f"_handle_tool_failure: tool failure encountered for code '{code}' with error: {error}" - ) - self.tool_failure_consecutive_failures += 1 - max_failures = self.config.get("max_tool_failures", DEFAULT_MAX_TOOL_FAILURES) - logger.debug( - f"_handle_tool_failure: failure count {self.tool_failure_consecutive_failures}, max_failures {max_failures}" - ) + if not self.fallback_enabled: + return None + + failed_tool_call_name = self.extract_failed_tool_name(error) if ( - self.fallback_enabled - and self.tool_failure_consecutive_failures >= max_failures + self.current_failing_tool_name + and failed_tool_call_name != self.current_failing_tool_name + ): + logger.debug( + "New failing tool name identified. Resetting consecutive tool failures." + ) + self.reset_fallback_handler() + + logger.debug( + f"_handle_tool_failure: tool failure encountered for code '{failed_tool_call_name}' with error: {error}" + ) + + self.current_failing_tool_name = failed_tool_call_name + self.current_tool_to_bind = self._find_tool_to_bind( + agent, failed_tool_call_name + ) + + if hasattr(error, "base_message") and error.base_message: + self.failed_messages.add(str(error.base_message)) + + self.tool_failure_consecutive_failures += 1 + logger.debug( + f"_handle_tool_failure: failure count {self.tool_failure_consecutive_failures}, max_failures {self.max_failures}" + ) + + if ( + self.tool_failure_consecutive_failures >= self.max_failures and self.fallback_tool_models ): logger.debug( "_handle_tool_failure: threshold reached, invoking fallback mechanism." ) - return self.attempt_fallback(code, logger, agent) + return self.attempt_fallback() - def attempt_fallback(self, code: str, logger, agent): + def attempt_fallback(self): """ - Initiate the fallback process by selecting a fallback model and triggering the appropriate fallback method. + Initiate the fallback process by iterating over all fallback models and triggering the appropriate fallback method. - Args: - code (str): The tool code that triggered the fallback. - logger: Logger instance for logging messages. - agent: The agent for which fallback is being executed. + Returns: + The response from a fallback model if any, otherwise None. """ - logger.debug(f"_attempt_fallback: initiating fallback for code: {code}") - fallback_model = self.fallback_tool_models[0] - failed_tool_call_name = code logger.error( - f"Tool call failed {self.tool_failure_consecutive_failures} times. Attempting fallback to model: {fallback_model['model']} for tool: {failed_tool_call_name}" + f"Tool call failed {self.tool_failure_consecutive_failures} times. Attempting fallback for tool: {self.current_failing_tool_name}" ) cpm( - f"**Tool fallback activated**: Switching to fallback model {fallback_model['model']} for tool {failed_tool_call_name}.", + f"**Tool fallback activated**: Attempting fallback for tool {self.current_failing_tool_name}.", title="Fallback Notification", ) - if fallback_model.get("type", "prompt").lower() == "fc": - self.attempt_fallback_function(code, logger, agent) - else: - self.attempt_fallback_prompt(code, logger, agent) + for fallback_model in self.fallback_tool_models: + if fallback_model.get("type", "prompt").lower() == "fc": + response = self.attempt_fallback_function(fallback_model) + else: + response = self.attempt_fallback_prompt(fallback_model) + if response: + return response + cpm("All fallback models have failed", title="Fallback Failed") + return None def reset_fallback_handler(self): """ Reset the fallback handler's internal failure counters and clear the record of used fallback models. """ self.tool_failure_consecutive_failures = 0 + self.failed_messages.clear() self.tool_failure_used_fallbacks.clear() + self.fallback_tool_models = self._load_fallback_tool_models(self.config) + + def extract_failed_tool_name(self, error: ToolExecutionError): + if error.tool_name: + failed_tool_call_name = error.tool_name + else: + msg = str(error) + logger.debug("Error message: %s", msg) + match = re.search(r"name=['\"](\w+)['\"]", msg) + if match: + failed_tool_call_name = str(match.group(1)) + logger.debug( + "Extracted failed_tool_call_name using regex: %s", + failed_tool_call_name, + ) + else: + failed_tool_call_name = "Tool execution error" + raise Exception("Fallback failed: Could not extract failed tool name.") + + return failed_tool_call_name def _find_tool_to_bind(self, agent, failed_tool_call_name): logger.debug(f"failed_tool_call_name={failed_tool_call_name}") @@ -157,135 +200,108 @@ class FallbackHandler: None, ) if tool_to_bind is None: - from ra_aid.tool_configs import get_all_tools - all_tools = get_all_tools() tool_to_bind = next( (t for t in all_tools if t.func.__name__ == failed_tool_call_name), None, ) if tool_to_bind is None: - available = [t.func.__name__ for t in get_all_tools()] - logger.debug( - f"Failed to find tool: {failed_tool_call_name}. Available tools: {available}" + # TODO: Would be nice to try fuzzy match or levenstein str match to find closest correspond tool name + raise Exception( + f"Fallback failed: {failed_tool_call_name} not found in all tools." ) - raise Exception(f"Tool {failed_tool_call_name} not found in all tools.") return tool_to_bind - def attempt_fallback_prompt(self, code: str, logger, agent): + def attempt_fallback_prompt(self, fallback_model): """ - Attempt a prompt-based fallback by iterating over fallback models and invoking the provided code. - - This method tries each fallback model (with retry logic configured) until one successfully executes the code. + Attempt a prompt-based fallback by invoking the current failing tool with the given fallback model. Args: - code (str): The tool code to invoke via fallback. - logger: Logger instance for logging messages. - agent: The agent instance to update with the new model upon success. + fallback_model (dict): The fallback model to use. Returns: - The response from the fallback model invocation. - - Raises: - Exception: If all prompt-based fallback models fail. + The response from the fallback model invocation, or None if failed. """ - logger.debug("Attempting prompt-based fallback using fallback models") - failed_tool_call_name = code - for fallback_model in self.fallback_tool_models: - try: - logger.debug(f"Trying fallback model: {fallback_model['model']}") - simple_model = initialize_llm( - fallback_model["provider"], fallback_model["model"] - ) - tool_to_bind = self._find_tool_to_bind(agent, failed_tool_call_name) - binded_model = simple_model.bind_tools( - [tool_to_bind], tool_choice=failed_tool_call_name - ) - # retry_model = binded_model.with_retry( - # stop_after_attempt=RETRY_FALLBACK_COUNT - # ) - response = binded_model.invoke(code) - cpm(f"response={response}") - - self.tool_failure_used_fallbacks.add(fallback_model["model"]) - - tool_call = self.base_message_to_tool_call_dict(response) - if tool_call: - result = self.invoke_prompt_tool_call(tool_call) - cpm(f"result={result}") - logger.debug( - "Prompt-based fallback executed successfully with model: " - + fallback_model["model"] - ) - self.reset_fallback_handler() - return result - else: - cpm( - response.content if hasattr(response, "content") else response, - title="Fallback Model Response: " + fallback_model["model"], - ) - return response - except Exception as e: - if isinstance(e, KeyboardInterrupt): - raise - logger.error( - f"Prompt-based fallback with model {fallback_model['model']} failed: {e}" - ) - raise Exception("All prompt-based fallback models failed") - - def attempt_fallback_function(self, code: str, logger, agent): - """ - Attempt a function-calling fallback by iterating over fallback models and invoking the provided code. - - This method tries each fallback model (with retry logic configured) until one successfully executes the code. - - Args: - code (str): The tool code to invoke via fallback. - logger: Logger instance for logging messages. - agent: The agent instance to update with the new model upon success. - - Returns: - The response from the fallback model invocation. - - Raises: - Exception: If all function-calling fallback models fail. - """ - logger.debug("Attempting function-calling fallback using fallback models") - failed_tool_call_name = code - for fallback_model in self.fallback_tool_models: - try: - logger.debug(f"Trying fallback model: {fallback_model['model']}") - simple_model = initialize_llm( - fallback_model["provider"], fallback_model["model"] - ) - tool_to_bind = self._find_tool_to_bind(agent, failed_tool_call_name) - binded_model = simple_model.bind_tools( - [tool_to_bind], tool_choice=failed_tool_call_name - ) - retry_model = binded_model.with_retry( - stop_after_attempt=RETRY_FALLBACK_COUNT - ) - response = retry_model.invoke(code) - cpm(f"response={response}") - self.tool_failure_used_fallbacks.add(fallback_model["model"]) - self.reset_fallback_handler() + try: + logger.debug(f"Trying fallback model: {fallback_model['model']}") + simple_model = initialize_llm( + fallback_model["provider"], fallback_model["model"] + ) + binded_model = simple_model.bind_tools( + [self.current_tool_to_bind], + tool_choice=self.current_failing_tool_name, + ) + retry_model = binded_model.with_retry( + stop_after_attempt=RETRY_FALLBACK_COUNT + ) + response = retry_model.invoke(self.current_failing_tool_name) + cpm(f"response={response}") + self.tool_failure_used_fallbacks.add(fallback_model["model"]) + tool_call = self.base_message_to_tool_call_dict(response) + if tool_call: + result = self.invoke_prompt_tool_call(tool_call) + cpm(f"result={result}") logger.debug( - "Function-calling fallback executed successfully with model: " + "Prompt-based fallback executed successfully with model: " + fallback_model["model"] ) - + self.reset_fallback_handler() + return result + else: cpm( response.content if hasattr(response, "content") else response, title="Fallback Model Response: " + fallback_model["model"], ) return response - except Exception as e: - if isinstance(e, KeyboardInterrupt): - raise - logger.error( - f"Function-calling fallback with model {fallback_model['model']} failed: {e}" - ) - raise Exception("All function-calling fallback models failed") + except Exception as e: + if isinstance(e, KeyboardInterrupt): + raise + logger.error( + f"Prompt-based fallback with model {fallback_model['model']} failed: {e}" + ) + return None + + def attempt_fallback_function(self, fallback_model): + """ + Attempt a function-calling fallback by invoking the current failing tool with the given fallback model. + + Args: + fallback_model (dict): The fallback model to use. + + Returns: + The response from the fallback model invocation, or None if failed. + """ + try: + logger.debug(f"Trying fallback model: {fallback_model['model']}") + simple_model = initialize_llm( + fallback_model["provider"], fallback_model["model"] + ) + binded_model = simple_model.bind_tools( + [self.current_tool_to_bind], + tool_choice=self.current_failing_tool_name, + ) + retry_model = binded_model.with_retry( + stop_after_attempt=RETRY_FALLBACK_COUNT + ) + response = retry_model.invoke(self.current_failing_tool_name) + cpm(f"response={response}") + self.tool_failure_used_fallbacks.add(fallback_model["model"]) + logger.debug( + "Function-calling fallback executed successfully with model: " + + fallback_model["model"] + ) + cpm( + response.content if hasattr(response, "content") else response, + title="Fallback Model Response: " + fallback_model["model"], + ) + return response + except Exception as e: + if isinstance(e, KeyboardInterrupt): + raise + logger.error( + f"Function-calling fallback with model {fallback_model['model']} failed: {e}" + ) + return None def invoke_prompt_tool_call(self, tool_call_request: dict): """ diff --git a/ra_aid/logging_config.py b/ra_aid/logging_config.py index fb3bf63..ba4609f 100644 --- a/ra_aid/logging_config.py +++ b/ra_aid/logging_config.py @@ -1,9 +1,10 @@ import logging import sys from typing import Optional + from rich.console import Console -from rich.panel import Panel from rich.markdown import Markdown +from rich.panel import Panel class PrettyHandler(logging.Handler): diff --git a/ra_aid/tool_configs.py b/ra_aid/tool_configs.py index 8fce691..b982613 100644 --- a/ra_aid/tool_configs.py +++ b/ra_aid/tool_configs.py @@ -1,3 +1,5 @@ +from langchain_core.tools import BaseTool + from ra_aid.tools import ( ask_expert, ask_human, @@ -61,11 +63,13 @@ def get_read_only_tools( return tools + def get_all_tools_simple(): """Return a list containing all available tools using existing group methods.""" return get_all_tools() -def get_all_tools(): + +def get_all_tools() -> list[BaseTool]: """Return a list containing all available tools from different groups.""" all_tools = [] all_tools.extend(get_read_only_tools()) @@ -176,7 +180,7 @@ def get_implementation_tools( return tools -def get_web_research_tools(expert_enabled: bool = True) -> list: +def get_web_research_tools(expert_enabled: bool = True): """Get the list of tools available for web research. Args: @@ -196,9 +200,7 @@ def get_web_research_tools(expert_enabled: bool = True) -> list: return tools -def get_chat_tools( - expert_enabled: bool = True, web_research_enabled: bool = False -) -> list: +def get_chat_tools(expert_enabled: bool = True, web_research_enabled: bool = False): """Get the list of tools available in chat mode. Chat mode includes research and implementation capabilities but excludes diff --git a/tests/ra_aid/test_agent_utils.py b/tests/ra_aid/test_agent_utils.py index 7a02a85..2e54f5a 100644 --- a/tests/ra_aid/test_agent_utils.py +++ b/tests/ra_aid/test_agent_utils.py @@ -276,11 +276,19 @@ def test_get_model_token_limit_planner(mock_memory): token_limit = get_model_token_limit(config, "planner") assert token_limit == 120000 + # New tests for private helper methods in agent_utils.py + def test_setup_and_restore_interrupt_handling(): - import signal, threading - from ra_aid.agent_utils import _setup_interrupt_handling, _restore_interrupt_handling, _request_interrupt + import signal + + from ra_aid.agent_utils import ( + _request_interrupt, + _restore_interrupt_handling, + _setup_interrupt_handling, + ) + original_handler = signal.getsignal(signal.SIGINT) handler = _setup_interrupt_handling() # Verify the SIGINT handler is set to _request_interrupt @@ -289,66 +297,93 @@ def test_setup_and_restore_interrupt_handling(): # Verify the SIGINT handler is restored to the original assert signal.getsignal(signal.SIGINT) == original_handler + def test_increment_and_decrement_agent_depth(): - from ra_aid.agent_utils import _increment_agent_depth, _decrement_agent_depth, _global_memory + from ra_aid.agent_utils import ( + _decrement_agent_depth, + _global_memory, + _increment_agent_depth, + ) + _global_memory["agent_depth"] = 10 _increment_agent_depth() assert _global_memory["agent_depth"] == 11 _decrement_agent_depth() assert _global_memory["agent_depth"] == 10 + def test_run_agent_stream(monkeypatch): - from ra_aid.agent_utils import _run_agent_stream, _global_memory + from ra_aid.agent_utils import _global_memory, _run_agent_stream + # Create a dummy agent that yields one chunk class DummyAgent: def stream(self, msg, cfg): yield {"content": "chunk1"} + dummy_agent = DummyAgent() # Set flags so that _run_agent_stream will reset them _global_memory["plan_completed"] = True _global_memory["task_completed"] = True _global_memory["completion_message"] = "existing" call_flag = {"called": False} + def fake_print_agent_output(chunk): call_flag["called"] = True - monkeypatch.setattr("ra_aid.agent_utils.print_agent_output", fake_print_agent_output) + + monkeypatch.setattr( + "ra_aid.agent_utils.print_agent_output", fake_print_agent_output + ) _run_agent_stream(dummy_agent, "dummy prompt", {}) assert call_flag["called"] assert _global_memory["plan_completed"] is False assert _global_memory["task_completed"] is False assert _global_memory["completion_message"] == "" + def test_execute_test_command_wrapper(monkeypatch): from ra_aid.agent_utils import _execute_test_command_wrapper + # Patch execute_test_command to return a testable tuple def fake_execute(config, orig, tests, auto): return (True, "new prompt", auto, tests + 1) + monkeypatch.setattr("ra_aid.agent_utils.execute_test_command", fake_execute) result = _execute_test_command_wrapper("orig", {}, 0, False) assert result == (True, "new prompt", False, 1) + def test_handle_api_error_valueerror(): - from ra_aid.agent_utils import _handle_api_error import pytest + + from ra_aid.agent_utils import _handle_api_error + # ValueError not containing "code" or "429" should be re-raised with pytest.raises(ValueError): _handle_api_error(ValueError("some error"), 0, 5, 1) + def test_handle_api_error_max_retries(): - from ra_aid.agent_utils import _handle_api_error import pytest + + from ra_aid.agent_utils import _handle_api_error + # When attempt reaches max retries, a RuntimeError should be raised with pytest.raises(RuntimeError): _handle_api_error(Exception("error code 429"), 4, 5, 1) + def test_handle_api_error_retry(monkeypatch): - from ra_aid.agent_utils import _handle_api_error import time + + from ra_aid.agent_utils import _handle_api_error + # Patch time.monotonic and time.sleep to simulate immediate delay expiration fake_time = [0] + def fake_monotonic(): fake_time[0] += 0.5 return fake_time[0] + monkeypatch.setattr(time, "monotonic", fake_monotonic) monkeypatch.setattr(time, "sleep", lambda s: None) # Should not raise error when attempt is lower than max retries diff --git a/tests/ra_aid/test_fallback_handler.py b/tests/ra_aid/test_fallback_handler.py index 6f0c285..c400a19 100644 --- a/tests/ra_aid/test_fallback_handler.py +++ b/tests/ra_aid/test_fallback_handler.py @@ -1,28 +1,41 @@ import unittest + from ra_aid.fallback_handler import FallbackHandler + class DummyLogger: def debug(self, msg): pass + def error(self, msg): pass + class DummyAgent: provider = "openai" tools = [] model = None + class TestFallbackHandler(unittest.TestCase): def setUp(self): - self.config = {"max_tool_failures": 2, "fallback_tool_models": "dummy-fallback-model"} + self.config = { + "max_tool_failures": 2, + "fallback_tool_models": "dummy-fallback-model", + } self.fallback_handler = FallbackHandler(self.config) self.logger = DummyLogger() self.agent = DummyAgent() def test_handle_failure_increments_counter(self): initial_failures = self.fallback_handler.tool_failure_consecutive_failures - self.fallback_handler.handle_failure("dummy_call()", Exception("Test error"), self.logger, self.agent) - self.assertEqual(self.fallback_handler.tool_failure_consecutive_failures, initial_failures + 1) + self.fallback_handler.handle_failure( + "dummy_call()", Exception("Test error"), self.logger, self.agent + ) + self.assertEqual( + self.fallback_handler.tool_failure_consecutive_failures, + initial_failures + 1, + ) def test_attempt_fallback_resets_counter(self): # Monkey-patch dummy functions for fallback components @@ -30,6 +43,7 @@ class TestFallbackHandler(unittest.TestCase): class DummyModel: def bind_tools(self, tools, tool_choice): pass + return DummyModel() def dummy_merge_chat_history(): @@ -39,6 +53,7 @@ class TestFallbackHandler(unittest.TestCase): return True import ra_aid.llm as llm + original_initialize = llm.initialize_llm original_merge = llm.merge_chat_history original_validate = llm.validate_provider_env @@ -47,12 +62,15 @@ class TestFallbackHandler(unittest.TestCase): llm.validate_provider_env = dummy_validate_provider_env self.fallback_handler.tool_failure_consecutive_failures = 2 - self.fallback_handler.attempt_fallback("dummy_tool_call()", self.logger, self.agent) + self.fallback_handler.attempt_fallback( + "dummy_tool_call()", self.logger, self.agent + ) self.assertEqual(self.fallback_handler.tool_failure_consecutive_failures, 0) llm.initialize_llm = original_initialize llm.merge_chat_history = original_merge llm.validate_provider_env = original_validate + if __name__ == "__main__": unittest.main()