RA.Aid/ra_aid/fallback_handler.py

404 lines
16 KiB
Python

import json
import re
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.tools import BaseTool
from langgraph.graph.message import BaseMessage
from ra_aid.agents_alias import RAgents
from ra_aid.config import (
DEFAULT_MAX_TOOL_FAILURES,
FALLBACK_TOOL_MODEL_LIMIT,
RETRY_FALLBACK_COUNT,
)
from ra_aid.console.output import cpm
from ra_aid.exceptions import FallbackToolExecutionError, ToolExecutionError
from ra_aid.llm import initialize_llm, validate_provider_env
from ra_aid.logging_config import get_logger
from ra_aid.tool_configs import get_all_tools
from ra_aid.tool_leaderboard import supported_top_tool_models
# Module-level logger used by FallbackHandler for debug and error diagnostics.
logger = get_logger(__name__)
class FallbackHandler:
    """
    FallbackHandler manages fallback logic when tool execution fails.

    It selects fallback models from ``supported_top_tool_models`` (keeping only
    those whose provider environment validates) and counts consecutive
    failures of the currently failing tool. Once the count reaches
    ``max_failures``, it asks each fallback model in turn to re-issue the
    failing tool call, using either a prompt-based or a forced
    function-calling binding. Internal state is reset in three cases: after a
    successful fallback, when every fallback model fails, or when a different
    tool starts failing.
    """

    def __init__(self, config, tools):
        """
        Initialize the FallbackHandler with the given configuration and tools.

        Args:
            config (dict): Configuration dictionary. Reads
                ``experimental_fallback_handler`` (enables fallback, default
                False) and ``max_tool_failures`` (failure threshold, default
                DEFAULT_MAX_TOOL_FAILURES).
            tools (list): Tools available for executing the fallback tool call.
        """
        self.config = config
        self.tools: list[BaseTool] = tools
        self.fallback_enabled = config.get("experimental_fallback_handler", False)
        self.fallback_tool_models = self._load_fallback_tool_models(config)
        self.max_failures = config.get("max_tool_failures", DEFAULT_MAX_TOOL_FAILURES)
        # Consecutive failure count for `current_failing_tool_name`.
        self.tool_failure_consecutive_failures = 0
        # Messages attached to failed tool calls; replayed to the fallback model.
        self.failed_messages: list[BaseMessage] = []
        self.current_failing_tool_name = ""
        self.current_tool_to_bind: None | BaseTool = None
        cpm(
            "Fallback models selected: "
            + ", ".join([self._format_model(m) for m in self.fallback_tool_models]),
            title="Fallback Models",
        )

    def _format_model(self, m: dict) -> str:
        """Return a short ``"model (type)"`` label for a fallback model entry."""
        return f"{m.get('model', '')} ({m.get('type', 'prompt')})"

    @staticmethod
    def _tool_func_name(tool):
        """
        Return the name used to match a tool against a failing tool call.

        Prefers ``tool.func.__name__`` and falls back to ``tool.name`` for
        tools that do not expose a ``func``; returns None if neither exists.
        """
        func_name = getattr(getattr(tool, "func", None), "__name__", None)
        return func_name or getattr(tool, "name", None)

    def _match_tool(self, tools, tool_name):
        """Return the first tool in *tools* whose name equals *tool_name*, else None."""
        return next(
            (t for t in (tools or []) if self._tool_func_name(t) == tool_name),
            None,
        )

    def _load_fallback_tool_models(self, _config):
        """
        Select fallback tool models from ``supported_top_tool_models``.

        Entries are taken in leaderboard order. Only those whose provider
        passes ``validate_provider_env`` are kept, up to
        FALLBACK_TOOL_MODEL_LIMIT entries. The configuration argument is
        currently not consulted.

        Args:
            _config (dict): Configuration dictionary (currently unused).

        Returns:
            list of dict: Copies of the selected leaderboard entries. Each
            copy has a lower-cased ``model`` and a ``type`` that defaults to
            ``"prompt"``.
        """
        supported = []
        skipped = []
        for item in supported_top_tool_models:
            if len(supported) >= FALLBACK_TOOL_MODEL_LIMIT:
                break
            if validate_provider_env(item.get("provider")):
                supported.append(item)
            else:
                skipped.append(item.get("model"))

        final_models = []
        for item in supported:
            # Copy so the shared module-level leaderboard entries are never mutated.
            model = dict(item)
            model.setdefault("type", "prompt")
            model["model"] = model["model"].lower()
            final_models.append(model)

        message = "Fallback models selected: " + ", ".join(
            m["model"] for m in final_models
        )
        if skipped:
            message += (
                "\nSkipped top tool calling models due to missing provider ENV API keys: "
                + ", ".join(str(name) for name in skipped)
            )
        logger.debug(message)
        return final_models

    def handle_failure(self, error: "ToolExecutionError", agent: "RAgents"):
        """
        Record a tool failure and trigger fallback once the threshold is reached.

        Args:
            error (ToolExecutionError): The exception raised during execution.
                If it carries a truthy ``base_message`` attribute, that message
                is recorded for the fallback prompt.
            agent: The agent instance whose ``tools`` are searched first for
                the failing tool.

        Returns:
            The result of ``attempt_fallback`` (``[raw_llm_response,
            tool_call_result]``) when the failure threshold is reached and
            fallback models exist; otherwise None.

        Raises:
            FallbackToolExecutionError: If the failing tool cannot be
                identified or found, or if every fallback model fails.
        """
        if not self.fallback_enabled:
            return None

        failed_tool_call_name = self.extract_failed_tool_name(error)
        # A different tool failing starts a fresh failure streak.
        self._reset_on_new_failure(failed_tool_call_name)

        logger.debug(
            f"_handle_tool_failure: tool failure encountered for code '{failed_tool_call_name}' with error: {error}"
        )
        self.current_failing_tool_name = failed_tool_call_name
        self.current_tool_to_bind = self._find_tool_to_bind(
            agent, failed_tool_call_name
        )

        if getattr(error, "base_message", None):
            self.failed_messages.append(error.base_message)

        self.tool_failure_consecutive_failures += 1
        logger.debug(
            f"_handle_tool_failure: failure count {self.tool_failure_consecutive_failures}, max_failures {self.max_failures}"
        )

        if (
            self.tool_failure_consecutive_failures >= self.max_failures
            and self.fallback_tool_models
        ):
            logger.debug(
                "_handle_tool_failure: threshold reached, invoking fallback mechanism."
            )
            return self.attempt_fallback()
        return None

    def attempt_fallback(self):
        """
        Try each fallback model in order until one fixes the failing tool call.

        Returns:
            ``[raw_llm_response, tool_call_result]`` from the first successful
            fallback model.

        Raises:
            FallbackToolExecutionError: If every fallback model fails. The
                handler state is reset before raising.
        """
        logger.error(
            f"Tool call failed {self.tool_failure_consecutive_failures} times. Attempting fallback for tool: {self.current_failing_tool_name}"
        )
        cpm(
            f"**Tool fallback activated**: Attempting fallback for tool {self.current_failing_tool_name}.",
            title="Fallback Notification",
        )
        for fallback_model in self.fallback_tool_models:
            result_list = self.invoke_fallback(fallback_model)
            if result_list:
                return result_list

        cpm("All fallback models have failed.", title="Fallback Failed")
        # Capture the name before the reset clears it.
        current_failing_tool_name = self.current_failing_tool_name
        self.reset_fallback_handler()
        raise FallbackToolExecutionError(
            f"All fallback models have failed for tool: {current_failing_tool_name}"
        )

    def reset_fallback_handler(self):
        """
        Reset the handler's failure state.

        Clears the consecutive-failure counter, the recorded failure messages,
        and the current failing tool and bound tool. It also reloads the
        fallback model list via ``_load_fallback_tool_models``.
        """
        self.tool_failure_consecutive_failures = 0
        self.failed_messages.clear()
        self.fallback_tool_models = self._load_fallback_tool_models(self.config)
        self.current_failing_tool_name = ""
        self.current_tool_to_bind = None

    def _reset_on_new_failure(self, failed_tool_call_name):
        """Reset state when a tool other than the currently tracked one fails."""
        if (
            self.current_failing_tool_name
            and failed_tool_call_name != self.current_failing_tool_name
        ):
            logger.debug(
                "New failing tool name identified. Resetting consecutive tool failures."
            )
            self.reset_fallback_handler()

    def extract_failed_tool_name(self, error: "ToolExecutionError"):
        """
        Determine the name of the tool whose call failed.

        Uses ``error.tool_name`` when it is set. Otherwise the error text is
        searched for a ``name='...'`` or ``name="..."`` fragment.

        Returns:
            str: The failing tool name.

        Raises:
            FallbackToolExecutionError: If no tool name can be extracted.
        """
        tool_name = getattr(error, "tool_name", None)
        if tool_name:
            return tool_name

        msg = str(error)
        logger.debug("Error message: %s", msg)
        match = re.search(r"name=['\"](\w+)['\"]", msg)
        if not match:
            raise FallbackToolExecutionError(
                "Fallback failed: Could not extract failed tool name."
            )
        failed_tool_call_name = match.group(1)
        logger.debug(
            "Extracted failed_tool_call_name using regex: %s",
            failed_tool_call_name,
        )
        return failed_tool_call_name

    def _find_tool_to_bind(self, agent, failed_tool_call_name):
        """
        Locate the tool object for the failing tool name.

        The agent's ``tools`` (if any) are searched first, then
        ``get_all_tools()``.

        Raises:
            FallbackToolExecutionError: If no tool with that name is found.
        """
        logger.debug(f"failed_tool_call_name={failed_tool_call_name}")
        tool_to_bind = self._match_tool(
            getattr(agent, "tools", None), failed_tool_call_name
        )
        if tool_to_bind is None:
            tool_to_bind = self._match_tool(get_all_tools(), failed_tool_call_name)
        if tool_to_bind is None:
            # TODO: Would be nice to try fuzzy match or levenstein str match to find closest correspond tool name
            raise FallbackToolExecutionError(
                f"Fallback failed failed_tool_call_name: '{failed_tool_call_name}' not found in any available tools."
            )
        return tool_to_bind

    def _bind_tool_model(self, simple_model: "BaseChatModel", fallback_model):
        """
        Bind the failing tool to *simple_model*.

        For ``type == "fc"`` the failing tool is forced via ``tool_choice``.
        Otherwise (prompt method) the tool is only made available.
        """
        if fallback_model.get("type", "prompt").lower() == "fc":
            # Force tool calling with tool_choice param.
            bound_model = simple_model.bind_tools(
                [self.current_tool_to_bind],
                tool_choice=self.current_failing_tool_name,
            )
        else:
            # Do not force tool calling (Prompt method)
            bound_model = simple_model.bind_tools([self.current_tool_to_bind])
        return bound_model

    def invoke_fallback(self, fallback_model):
        """
        Attempt a prompt or function-calling fallback with *fallback_model*.

        Args:
            fallback_model (dict): Fallback model entry with ``provider``,
                ``model`` and ``type`` keys.

        Returns:
            ``[raw_llm_response, tool_call_result]`` on success (the handler
            state is reset), or None if any step fails.
        """
        try:
            logger.debug(f"Trying fallback model: {self._format_model(fallback_model)}")
            simple_model = initialize_llm(
                fallback_model["provider"], fallback_model["model"]
            )
            bound_model = self._bind_tool_model(simple_model, fallback_model)
            retry_model = bound_model.with_retry(
                stop_after_attempt=RETRY_FALLBACK_COUNT
            )
            msg_list = self.construct_prompt_msg_list()
            response = retry_model.invoke(msg_list)
            logger.debug(f"raw llm response={response}")
            tool_call = self.base_message_to_tool_call_dict(response)
            tool_call_result = self.invoke_prompt_tool_call(tool_call)
            cpm(str(tool_call_result), title="Fallback Tool Call Result")
            logger.debug(
                f"Fallback call successful with model: {self._format_model(fallback_model)}"
            )
            self.reset_fallback_handler()
            return [response, tool_call_result]
        except Exception as e:
            # KeyboardInterrupt/SystemExit derive from BaseException, so they
            # are not caught here and propagate to the caller.
            logger.error(
                f"Fallback with model {self._format_model(fallback_model)} failed: {e}"
            )
            return None

    def construct_prompt_msg_list(self):
        """
        Construct the chat messages for the fallback prompt.

        The list has three parts:
        - a system message that sets the fallback-tool-caller role;
        - each recorded failure message, wrapped as a SystemMessage;
        - a human message asking to retry the failing tool with improved
          arguments.

        Returns:
            list: A list of chat messages.
        """
        msg_list: list[BaseMessage] = []
        msg_list.append(
            SystemMessage(
                content="You are a fallback tool caller. Your only responsibility is to figure out what the previous failed tool call was trying to do and to call that tool with the correct format and arguments, using the provided failure messages."
            )
        )
        if self.failed_messages:
            # Convert to system messages to avoid API errors asking for correct msg structure
            msg_list.extend([SystemMessage(str(msg)) for msg in self.failed_messages])
        msg_list.append(
            HumanMessage(
                content=f"Retry using the tool '{self.current_failing_tool_name}' with improved arguments."
            )
        )
        return msg_list

    def invoke_prompt_tool_call(self, tool_call_request: dict):
        """
        Invoke the tool named in a tool call request.

        The tool is looked up in ``self.tools``. If it is not there, the tool
        bound for this fallback (``current_tool_to_bind``) is used when its
        name matches. That tool may have come from the agent or from
        ``get_all_tools()`` rather than from ``self.tools``.

        Args:
            tool_call_request (dict): Tool call with at least ``name`` and
                ``arguments`` keys.

        Returns:
            The result of invoking the tool.

        Raises:
            KeyError: If no matching tool is available.
        """
        name = tool_call_request["name"]
        arguments = tool_call_request["arguments"]
        tool_name_to_tool = {
            self._tool_func_name(tool): tool for tool in (self.tools or [])
        }
        tool = tool_name_to_tool.get(name)
        if tool is None:
            bound_tool = getattr(self, "current_tool_to_bind", None)
            if bound_tool is not None and self._tool_func_name(bound_tool) == name:
                tool = bound_tool
        if tool is None:
            raise KeyError(f"Tool '{name}' not found among available tools")
        return tool.invoke(arguments)

    def base_message_to_tool_call_dict(self, response: "BaseMessage"):
        """
        Extract the first tool call from a model response as a flat dict.

        Two entry formats are accepted:
        - OpenAI-style entries from ``additional_kwargs["tool_calls"]``
          (``{"id", "type", "function": {"name", "arguments"}}``);
        - LangChain's normalized ``tool_calls`` entries
          (``{"name", "args", "id", "type"}``).

        Args:
            response: The response object containing tool call data.

        Returns:
            dict: Keys ``id``, ``type``, ``name`` and ``arguments``.
            ``arguments`` is decoded when given as a JSON string.

        Raises:
            ValueError: If the response contains no tool calls.
        """
        tool_calls = self.get_tool_calls(response)
        if not tool_calls:
            raise ValueError(
                f"Could not extract tool_call_dict from response: {response}"
            )
        if len(tool_calls) > 1:
            logger.warning("Multiple tool calls detected, using the first one")
        tool_call = tool_calls[0]

        function = tool_call.get("function")
        if function is not None:
            # OpenAI-style raw tool call.
            name = function["name"]
            raw_arguments = function.get("arguments")
        else:
            # LangChain ToolCall: name/args live at the top level.
            name = tool_call["name"]
            raw_arguments = tool_call.get("args")
        return {
            "id": tool_call.get("id"),
            "type": tool_call.get("type"),
            "name": name,
            "arguments": self._parse_tool_arguments(raw_arguments),
        }

    def _parse_tool_arguments(self, tool_arguments):
        """
        Normalize tool call arguments.

        JSON strings are decoded. None or an empty/whitespace-only string is
        treated as "no arguments" and becomes ``{}``. Any other value is
        returned unchanged.
        """
        if tool_arguments is None:
            return {}
        if isinstance(tool_arguments, str):
            if not tool_arguments.strip():
                return {}
            return json.loads(tool_arguments)
        return tool_arguments

    def get_tool_calls(self, response: "BaseMessage"):
        """
        Extract the tool calls list from a fallback response.

        Sources are checked in this order:
        1. ``response.additional_kwargs["tool_calls"]``;
        2. ``response.tool_calls``;
        3. for dict responses, ``response["additional_kwargs"]["tool_calls"]``.

        Args:
            response: The response object containing tool call data.

        Returns:
            The tool calls list if present, otherwise None.
        """
        tool_calls = None
        if hasattr(response, "additional_kwargs") and response.additional_kwargs.get(
            "tool_calls"
        ):
            tool_calls = response.additional_kwargs.get("tool_calls")
        elif hasattr(response, "tool_calls"):
            tool_calls = response.tool_calls
        elif isinstance(response, dict) and response.get("additional_kwargs", {}).get(
            "tool_calls"
        ):
            tool_calls = response.get("additional_kwargs").get("tool_calls")
        return tool_calls

    def handle_failure_response(
        self, error: "ToolExecutionError", agent, agent_type: str
    ):
        """
        Handle a tool failure and wrap any fallback result for React agents.

        Calls ``handle_failure``. If a fallback response is returned and
        *agent_type* is ``"React"``, the result is a list of SystemMessage
        objects, one per message in the fallback response. Otherwise the
        result is None.
        """
        fallback_response = self.handle_failure(error, agent)
        if fallback_response and agent_type == "React":
            return [SystemMessage(str(msg)) for msg in fallback_response]
        return None