import json import re from langchain_core.tools import BaseTool from langgraph.graph.graph import CompiledGraph from langgraph.graph.message import BaseMessage from ra_aid.agents.ciayn_agent import CiaynAgent from ra_aid.config import ( DEFAULT_MAX_TOOL_FAILURES, FALLBACK_TOOL_MODEL_LIMIT, RETRY_FALLBACK_COUNT, ) from ra_aid.console.output import cpm from ra_aid.exceptions import ToolExecutionError from ra_aid.llm import initialize_llm, validate_provider_env from ra_aid.logging_config import get_logger from ra_aid.tool_configs import get_all_tools from ra_aid.tool_leaderboard import supported_top_tool_models logger = get_logger(__name__) class FallbackHandler: """ FallbackHandler manages fallback logic when tool execution fails. It loads fallback models from configuration and validated provider settings, maintains failure counts, and triggers appropriate fallback methods for both prompt-based and function-calling tool invocations. It also resets internal counters when a tool call succeeds. """ def __init__(self, config, tools): """ Initialize the FallbackHandler with the given configuration and tools. Args: config (dict): Configuration dictionary that may include fallback settings. tools (list): List of available tools. """ self.config = config self.tools: list[BaseTool] = tools self.fallback_enabled = config.get("fallback_tool_enabled", True) self.fallback_tool_models = self._load_fallback_tool_models(config) self.max_failures = config.get("max_tool_failures", DEFAULT_MAX_TOOL_FAILURES) self.tool_failure_consecutive_failures = 0 self.failed_messages = set() self.tool_failure_used_fallbacks = set() self.current_failing_tool_name = "" self.current_tool_to_bind = None def _load_fallback_tool_models(self, config): """ Load and return fallback tool models based on the provided configuration. If the config specifies 'fallback_tool_models', those are used (assuming comma-separated names). Otherwise, this method filters the supported_top_tool_models based on provider environment validation, selecting up to FALLBACK_TOOL_MODEL_LIMIT models. Args: config (dict): Configuration dictionary. Returns: list of dict: Each dictionary contains keys 'model' and 'type' representing a fallback model. """ supported = [] skipped = [] for item in supported_top_tool_models: provider = item.get("provider") model_name = item.get("model") if validate_provider_env(provider): supported.append(item) if len(supported) == FALLBACK_TOOL_MODEL_LIMIT: break else: skipped.append(model_name) final_models = [] for item in supported: if "type" not in item: item["type"] = "prompt" item["model"] = item["model"].lower() final_models.append(item) message = "Fallback models selected: " + ", ".join( [m["model"] for m in final_models] ) if skipped: message += ( "\nSkipped top tool calling models due to missing provider ENV API keys: " + ", ".join(skipped) ) cpm(message, title="Fallback Models") return final_models def handle_failure(self, error: Exception, agent: CiaynAgent | CompiledGraph): """ Handle a tool failure by incrementing the failure counter and triggering fallback if thresholds are exceeded. Args: error (Exception): The exception raised during execution. If the exception has a 'base_message' attribute, that message is recorded. agent: The agent instance on which fallback may be executed. """ if not self.fallback_enabled: return None failed_tool_call_name = self.extract_failed_tool_name(error) if ( self.current_failing_tool_name and failed_tool_call_name != self.current_failing_tool_name ): logger.debug( "New failing tool name identified. Resetting consecutive tool failures." ) self.reset_fallback_handler() logger.debug( f"_handle_tool_failure: tool failure encountered for code '{failed_tool_call_name}' with error: {error}" ) self.current_failing_tool_name = failed_tool_call_name self.current_tool_to_bind = self._find_tool_to_bind( agent, failed_tool_call_name ) if hasattr(error, "base_message") and error.base_message: self.failed_messages.add(str(error.base_message)) self.tool_failure_consecutive_failures += 1 logger.debug( f"_handle_tool_failure: failure count {self.tool_failure_consecutive_failures}, max_failures {self.max_failures}" ) if ( self.tool_failure_consecutive_failures >= self.max_failures and self.fallback_tool_models ): logger.debug( "_handle_tool_failure: threshold reached, invoking fallback mechanism." ) return self.attempt_fallback() def attempt_fallback(self): """ Initiate the fallback process by iterating over all fallback models and triggering the appropriate fallback method. Returns: The response from a fallback model if any, otherwise None. """ logger.error( f"Tool call failed {self.tool_failure_consecutive_failures} times. Attempting fallback for tool: {self.current_failing_tool_name}" ) cpm( f"**Tool fallback activated**: Attempting fallback for tool {self.current_failing_tool_name}.", title="Fallback Notification", ) for fallback_model in self.fallback_tool_models: if fallback_model.get("type", "prompt").lower() == "fc": response = self.attempt_fallback_function(fallback_model) else: response = self.attempt_fallback_prompt(fallback_model) if response: return response cpm("All fallback models have failed", title="Fallback Failed") return None def reset_fallback_handler(self): """ Reset the fallback handler's internal failure counters and clear the record of used fallback models. """ self.tool_failure_consecutive_failures = 0 self.failed_messages.clear() self.tool_failure_used_fallbacks.clear() self.fallback_tool_models = self._load_fallback_tool_models(self.config) def extract_failed_tool_name(self, error: ToolExecutionError): if error.tool_name: failed_tool_call_name = error.tool_name else: msg = str(error) logger.debug("Error message: %s", msg) match = re.search(r"name=['\"](\w+)['\"]", msg) if match: failed_tool_call_name = str(match.group(1)) logger.debug( "Extracted failed_tool_call_name using regex: %s", failed_tool_call_name, ) else: failed_tool_call_name = "Tool execution error" raise Exception("Fallback failed: Could not extract failed tool name.") return failed_tool_call_name def _find_tool_to_bind(self, agent, failed_tool_call_name): logger.debug(f"failed_tool_call_name={failed_tool_call_name}") tool_to_bind = None if hasattr(agent, "tools"): tool_to_bind = next( (t for t in agent.tools if t.func.__name__ == failed_tool_call_name), None, ) if tool_to_bind is None: all_tools = get_all_tools() tool_to_bind = next( (t for t in all_tools if t.func.__name__ == failed_tool_call_name), None, ) if tool_to_bind is None: # TODO: Would be nice to try fuzzy match or levenstein str match to find closest correspond tool name raise Exception( f"Fallback failed: {failed_tool_call_name} not found in all tools." ) return tool_to_bind def attempt_fallback_prompt(self, fallback_model): """ Attempt a prompt-based fallback by invoking the current failing tool with the given fallback model. Args: fallback_model (dict): The fallback model to use. Returns: The response from the fallback model invocation, or None if failed. """ try: logger.debug(f"Trying fallback model: {fallback_model['model']}") simple_model = initialize_llm( fallback_model["provider"], fallback_model["model"] ) binded_model = simple_model.bind_tools( [self.current_tool_to_bind], tool_choice=self.current_failing_tool_name, ) retry_model = binded_model.with_retry( stop_after_attempt=RETRY_FALLBACK_COUNT ) response = retry_model.invoke(self.current_failing_tool_name) cpm(f"response={response}") self.tool_failure_used_fallbacks.add(fallback_model["model"]) tool_call = self.base_message_to_tool_call_dict(response) if tool_call: result = self.invoke_prompt_tool_call(tool_call) cpm(f"result={result}") logger.debug( "Prompt-based fallback executed successfully with model: " + fallback_model["model"] ) self.reset_fallback_handler() return result else: cpm( response.content if hasattr(response, "content") else response, title="Fallback Model Response: " + fallback_model["model"], ) return response except Exception as e: if isinstance(e, KeyboardInterrupt): raise logger.error( f"Prompt-based fallback with model {fallback_model['model']} failed: {e}" ) return None def attempt_fallback_function(self, fallback_model): """ Attempt a function-calling fallback by invoking the current failing tool with the given fallback model. Args: fallback_model (dict): The fallback model to use. Returns: The response from the fallback model invocation, or None if failed. """ try: logger.debug(f"Trying fallback model: {fallback_model['model']}") simple_model = initialize_llm( fallback_model["provider"], fallback_model["model"] ) binded_model = simple_model.bind_tools( [self.current_tool_to_bind], tool_choice=self.current_failing_tool_name, ) retry_model = binded_model.with_retry( stop_after_attempt=RETRY_FALLBACK_COUNT ) response = retry_model.invoke(self.current_failing_tool_name) cpm(f"response={response}") self.tool_failure_used_fallbacks.add(fallback_model["model"]) logger.debug( "Function-calling fallback executed successfully with model: " + fallback_model["model"] ) cpm( response.content if hasattr(response, "content") else response, title="Fallback Model Response: " + fallback_model["model"], ) return response except Exception as e: if isinstance(e, KeyboardInterrupt): raise logger.error( f"Function-calling fallback with model {fallback_model['model']} failed: {e}" ) return None def invoke_prompt_tool_call(self, tool_call_request: dict): """ Invoke a tool call from a prompt-based fallback response. Args: tool_call_request (dict): The tool call request containing keys 'type', 'name', and 'arguments'. Returns: The result of invoking the tool. """ tool_name_to_tool = {tool.func.__name__: tool for tool in self.tools} name = tool_call_request["name"] arguments = tool_call_request["arguments"] return tool_name_to_tool[name].invoke(arguments) def base_message_to_tool_call_dict(self, response: BaseMessage): """ Extracts a tool call dictionary from a BaseMessage. Args: response: The response object containing tool call data. Returns: A tool call dictionary with keys 'id', 'type', 'name', and 'arguments' if a tool call is found, otherwise None. """ tool_calls = self.get_tool_calls(response) if tool_calls: if len(tool_calls) > 1: logger.warning("Multiple tool calls detected, using the first one") tool_call = tool_calls[0] return { "id": tool_call["id"], "type": tool_call["type"], "name": tool_call["function"]["name"], "arguments": self._parse_tool_arguments( tool_call["function"]["arguments"] ), } return None def _parse_tool_arguments(self, tool_arguments): """ Helper method to parse tool call arguments. If tool_arguments is a string, it returns the JSON-parsed dictionary. Otherwise, returns tool_arguments as is. """ if isinstance(tool_arguments, str): return json.loads(tool_arguments) return tool_arguments def get_tool_calls(self, response: BaseMessage): """ Extracts tool calls list from a fallback response. Args: response: The response object containing tool call data. Returns: The tool calls list if present, otherwise None. """ tool_calls = None if hasattr(response, "additional_kwargs") and response.additional_kwargs.get( "tool_calls" ): tool_calls = response.additional_kwargs.get("tool_calls") elif hasattr(response, "tool_calls"): tool_calls = response.tool_calls elif isinstance(response, dict) and response.get("additional_kwargs", {}).get( "tool_calls" ): tool_calls = response.get("additional_kwargs").get("tool_calls") return tool_calls