import json

from langchain_core.messages import BaseMessage
from langchain_core.tools import BaseTool
from langgraph.graph.graph import CompiledGraph
from rich.console import Console

from ra_aid.agents.ciayn_agent import CiaynAgent
from ra_aid.config import (
    DEFAULT_MAX_TOOL_FAILURES,
    FALLBACK_TOOL_MODEL_LIMIT,
    RETRY_FALLBACK_COUNT,
)
from ra_aid.console.output import cpm
from ra_aid.llm import initialize_llm, validate_provider_env
from ra_aid.logging_config import get_logger
from ra_aid.tool_leaderboard import supported_top_tool_models

logger = get_logger(__name__)


class FallbackHandler:
    """
    FallbackHandler manages fallback logic when tool execution fails.

    It loads fallback models from configuration and validated provider settings,
    maintains failure counts, and triggers the appropriate fallback method for
    both prompt-based and function-calling tool invocations. It also resets
    internal counters when a tool call succeeds.
    """

    def __init__(self, config, tools):
        """
        Initialize the FallbackHandler with the given configuration and tools.

        Args:
            config (dict): Configuration dictionary that may include fallback settings.
            tools (list): List of available tools.
        """
        self.config = config
        self.tools: list[BaseTool] = tools
        self.fallback_enabled = config.get("fallback_tool_enabled", True)
        self.fallback_tool_models = self._load_fallback_tool_models(config)
        self.tool_failure_consecutive_failures = 0
        self.tool_failure_used_fallbacks = set()
        self.console = Console()

    def _load_fallback_tool_models(self, config):
        """
        Load and return fallback tool models for the provided configuration.

        Filters supported_top_tool_models by provider environment validation,
        selecting up to FALLBACK_TOOL_MODEL_LIMIT models. Models whose provider
        API keys are missing from the environment are skipped and reported.

        Args:
            config (dict): Configuration dictionary.

        Returns:
            list of dict: Each dictionary contains at least the keys 'model' and
            'type' describing a fallback model.
        """
        supported = []
        skipped = []
        for item in supported_top_tool_models:
            provider = item.get("provider")
            model_name = item.get("model")
            if validate_provider_env(provider):
                supported.append(item)
                if len(supported) == FALLBACK_TOOL_MODEL_LIMIT:
                    break
            else:
                skipped.append(model_name)

        final_models = []
        for item in supported:
            # Default to prompt-based fallback when no explicit type is given.
            if "type" not in item:
                item["type"] = "prompt"
            item["model"] = item["model"].lower()
            final_models.append(item)

        message = "Fallback models selected: " + ", ".join(
            [m["model"] for m in final_models]
        )
        if skipped:
            message += (
                "\nSkipped top tool calling models due to missing provider ENV API keys: "
                + ", ".join(skipped)
            )
        cpm(message, title="Fallback Models")
        return final_models

    def handle_failure(
        self, code: str, error: Exception, agent: CiaynAgent | CompiledGraph
    ):
        """
        Handle a tool failure by incrementing the failure counter and triggering
        fallback once the configured threshold is exceeded.

        Args:
            code (str): The code that failed to execute.
            error (Exception): The exception raised during execution.
            agent: The agent instance on which fallback may be executed.
""" logger.debug( f"_handle_tool_failure: tool failure encountered for code '{code}' with error: {error}" ) self.tool_failure_consecutive_failures += 1 max_failures = self.config.get("max_tool_failures", DEFAULT_MAX_TOOL_FAILURES) logger.debug( f"_handle_tool_failure: failure count {self.tool_failure_consecutive_failures}, max_failures {max_failures}" ) if ( self.fallback_enabled and self.tool_failure_consecutive_failures >= max_failures and self.fallback_tool_models ): logger.debug( "_handle_tool_failure: threshold reached, invoking fallback mechanism." ) return self.attempt_fallback(code, logger, agent) def attempt_fallback(self, code: str, logger, agent): """ Initiate the fallback process by selecting a fallback model and triggering the appropriate fallback method. Args: code (str): The tool code that triggered the fallback. logger: Logger instance for logging messages. agent: The agent for which fallback is being executed. """ logger.debug(f"_attempt_fallback: initiating fallback for code: {code}") fallback_model = self.fallback_tool_models[0] failed_tool_call_name = code logger.error( f"Tool call failed {self.tool_failure_consecutive_failures} times. Attempting fallback to model: {fallback_model['model']} for tool: {failed_tool_call_name}" ) cpm( f"**Tool fallback activated**: Switching to fallback model {fallback_model['model']} for tool {failed_tool_call_name}.", title="Fallback Notification", ) if fallback_model.get("type", "prompt").lower() == "fc": self.attempt_fallback_function(code, logger, agent) else: self.attempt_fallback_prompt(code, logger, agent) def reset_fallback_handler(self): """ Reset the fallback handler's internal failure counters and clear the record of used fallback models. """ self.tool_failure_consecutive_failures = 0 self.tool_failure_used_fallbacks.clear() def _find_tool_to_bind(self, agent, failed_tool_call_name): logger.debug(f"failed_tool_call_name={failed_tool_call_name}") tool_to_bind = None if hasattr(agent, "tools"): tool_to_bind = next( (t for t in agent.tools if t.func.__name__ == failed_tool_call_name), None, ) if tool_to_bind is None: from ra_aid.tool_configs import get_all_tools all_tools = get_all_tools() tool_to_bind = next( (t for t in all_tools if t.func.__name__ == failed_tool_call_name), None, ) if tool_to_bind is None: available = [t.func.__name__ for t in get_all_tools()] logger.debug( f"Failed to find tool: {failed_tool_call_name}. Available tools: {available}" ) raise Exception(f"Tool {failed_tool_call_name} not found in all tools.") return tool_to_bind def attempt_fallback_prompt(self, code: str, logger, agent): """ Attempt a prompt-based fallback by iterating over fallback models and invoking the provided code. This method tries each fallback model (with retry logic configured) until one successfully executes the code. Args: code (str): The tool code to invoke via fallback. logger: Logger instance for logging messages. agent: The agent instance to update with the new model upon success. Returns: The response from the fallback model invocation. Raises: Exception: If all prompt-based fallback models fail. 
""" logger.debug("Attempting prompt-based fallback using fallback models") failed_tool_call_name = code for fallback_model in self.fallback_tool_models: try: logger.debug(f"Trying fallback model: {fallback_model['model']}") simple_model = initialize_llm( fallback_model["provider"], fallback_model["model"] ) tool_to_bind = self._find_tool_to_bind(agent, failed_tool_call_name) binded_model = simple_model.bind_tools( [tool_to_bind], tool_choice=failed_tool_call_name ) # retry_model = binded_model.with_retry( # stop_after_attempt=RETRY_FALLBACK_COUNT # ) response = binded_model.invoke(code) cpm(f"response={response}") self.tool_failure_used_fallbacks.add(fallback_model["model"]) tool_call = self.base_message_to_tool_call_dict(response) if tool_call: result = self.invoke_prompt_tool_call(tool_call) cpm(f"result={result}") logger.debug( "Prompt-based fallback executed successfully with model: " + fallback_model["model"] ) self.reset_fallback_handler() return result else: cpm( response.content if hasattr(response, "content") else response, title="Fallback Model Response: " + fallback_model["model"], ) return response except Exception as e: if isinstance(e, KeyboardInterrupt): raise logger.error( f"Prompt-based fallback with model {fallback_model['model']} failed: {e}" ) raise Exception("All prompt-based fallback models failed") def attempt_fallback_function(self, code: str, logger, agent): """ Attempt a function-calling fallback by iterating over fallback models and invoking the provided code. This method tries each fallback model (with retry logic configured) until one successfully executes the code. Args: code (str): The tool code to invoke via fallback. logger: Logger instance for logging messages. agent: The agent instance to update with the new model upon success. Returns: The response from the fallback model invocation. Raises: Exception: If all function-calling fallback models fail. """ logger.debug("Attempting function-calling fallback using fallback models") failed_tool_call_name = code for fallback_model in self.fallback_tool_models: try: logger.debug(f"Trying fallback model: {fallback_model['model']}") simple_model = initialize_llm( fallback_model["provider"], fallback_model["model"] ) tool_to_bind = self._find_tool_to_bind(agent, failed_tool_call_name) binded_model = simple_model.bind_tools( [tool_to_bind], tool_choice=failed_tool_call_name ) retry_model = binded_model.with_retry( stop_after_attempt=RETRY_FALLBACK_COUNT ) response = retry_model.invoke(code) cpm(f"response={response}") self.tool_failure_used_fallbacks.add(fallback_model["model"]) self.reset_fallback_handler() logger.debug( "Function-calling fallback executed successfully with model: " + fallback_model["model"] ) cpm( response.content if hasattr(response, "content") else response, title="Fallback Model Response: " + fallback_model["model"], ) return response except Exception as e: if isinstance(e, KeyboardInterrupt): raise logger.error( f"Function-calling fallback with model {fallback_model['model']} failed: {e}" ) raise Exception("All function-calling fallback models failed") def invoke_prompt_tool_call(self, tool_call_request: dict): """ Invoke a tool call from a prompt-based fallback response. Args: tool_call_request (dict): The tool call request containing keys 'type', 'name', and 'arguments'. Returns: The result of invoking the tool. 
""" tool_name_to_tool = {tool.func.__name__: tool for tool in self.tools} name = tool_call_request["name"] arguments = tool_call_request["arguments"] # return tool_name_to_tool[name].invoke(arguments) # tool_call_dict = {"arguments": arguments} return tool_name_to_tool[name].invoke(arguments) def base_message_to_tool_call_dict(self, response: BaseMessage): """ Extracts a tool call dictionary from a fallback response. Args: response: The response object containing tool call data. Returns: A tool call dictionary with keys 'id', 'type', 'name', and 'arguments' if a tool call is found, otherwise None. """ tool_calls = None if hasattr(response, "additional_kwargs") and response.additional_kwargs.get( "tool_calls" ): tool_calls = response.additional_kwargs.get("tool_calls") elif hasattr(response, "tool_calls"): tool_calls = response.tool_calls elif isinstance(response, dict) and response.get("additional_kwargs", {}).get( "tool_calls" ): tool_calls = response.get("additional_kwargs").get("tool_calls") if tool_calls: if len(tool_calls) > 1: logger.warning("Multiple tool calls detected, using the first one") tool_call = tool_calls[0] return { "id": tool_call["id"], "type": tool_call["type"], "name": tool_call["function"]["name"], "arguments": ( json.loads(tool_call["function"]["arguments"]) if isinstance(tool_call["function"]["arguments"], str) else tool_call["function"]["arguments"] ), } return None