# RA.Aid/ra_aid/fallback_handler.py
#
# 197 lines
# 9.4 KiB
# Python

from ra_aid.config import DEFAULT_MAX_TOOL_FAILURES, FALLBACK_TOOL_MODEL_LIMIT, RETRY_FALLBACK_COUNT, RETRY_FALLBACK_DELAY
from ra_aid.tool_leaderboard import supported_top_tool_models
from rich.console import Console
from rich.markdown import Markdown
from rich.panel import Panel
from ra_aid.llm import initialize_llm, merge_chat_history, validate_provider_env
class FallbackHandler:
    """
    FallbackHandler manages fallback logic when tool execution fails.

    It loads fallback models from configuration (or from the supported tool
    leaderboard, filtered by provider environment validation), counts
    consecutive tool failures, and, once the configured threshold is reached,
    re-issues the failed tool call on the fallback models. Internal counters
    are reset after a successful fallback or via reset_fallback_handler().

    Attributes:
        config (dict): The configuration passed to the constructor.
        fallback_enabled (bool): Value of config["fallback_tool_enabled"], default True.
        fallback_tool_models (list): Dicts describing each fallback model.
        tool_failure_consecutive_failures (int): Failures seen since the last reset.
        tool_failure_used_fallbacks (set): Names of fallback models recorded as used.
    """

    def __init__(self, config):
        """
        Initialize the FallbackHandler with the given configuration.

        Args:
            config (dict): Configuration dictionary that may include fallback settings.
        """
        self.config = config
        self.fallback_enabled = config.get("fallback_tool_enabled", True)
        self.fallback_tool_models = self._load_fallback_tool_models(config)
        self.tool_failure_consecutive_failures = 0
        self.tool_failure_used_fallbacks = set()

    @staticmethod
    def _failed_tool_name(code):
        """Return the text of `code` before its first "(", stripped of whitespace."""
        return code.split("(")[0].strip()

    def _load_fallback_tool_models(self, config):
        """
        Load and return fallback tool models based on the provided configuration.

        If the config specifies 'fallback_tool_models', those are used. The value
        may be a comma-separated string of model names or a list/tuple of names.
        Otherwise supported_top_tool_models is filtered by provider environment
        validation, selecting up to FALLBACK_TOOL_MODEL_LIMIT models, and a panel
        summarising the selection is printed to the console.

        Args:
            config (dict): Configuration dictionary.

        Returns:
            list of dict: Each dictionary contains at least the key 'model'.
            Explicitly configured models also get 'type' set to "prompt".
        """
        configured = config.get("fallback_tool_models")
        if configured:
            # A string is treated as comma-separated names; any other value is
            # iterated as a sequence of names so a pre-split list also works.
            raw_names = configured.split(",") if isinstance(configured, str) else configured
            stripped = (str(name).strip() for name in raw_names)
            return [{"model": name, "type": "prompt"} for name in stripped if name]

        supported = []
        skipped = []
        for item in supported_top_tool_models:
            # Check the limit before examining the entry so that a limit of 0
            # selects nothing and entries past the limit are not reported as skipped.
            if len(supported) >= FALLBACK_TOOL_MODEL_LIMIT:
                break
            if validate_provider_env(item.get("provider")):
                supported.append(item)
            else:
                skipped.append(item.get("model"))

        message = "Fallback models selected: " + ", ".join([m["model"] for m in supported])
        if skipped:
            message += (
                "\nSkipped top tool calling models due to missing provider ENV API keys: "
                + ", ".join(skipped)
            )
        Console().print(Panel(Markdown(message), title="Fallback Models"))
        return supported

    def handle_failure(self, code: str, error: Exception, logger, agent):
        """
        Handle a tool failure by incrementing the failure counter and triggering
        fallback once the failure threshold is reached.

        Args:
            code (str): The code that failed to execute.
            error (Exception): The exception raised during execution.
            logger: Logger instance for logging.
            agent: The agent instance on which fallback may be executed.

        Returns:
            The fallback response when a fallback was attempted and succeeded,
            otherwise None.

        Raises:
            Exception: Propagated from the fallback attempt if every fallback model fails.
        """
        logger.debug(
            f"_handle_tool_failure: tool failure encountered for code '{code}' with error: {error}"
        )
        self.tool_failure_consecutive_failures += 1
        max_failures = self.config.get("max_tool_failures", DEFAULT_MAX_TOOL_FAILURES)
        logger.debug(
            f"_handle_tool_failure: failure count {self.tool_failure_consecutive_failures}, max_failures {max_failures}"
        )
        threshold_reached = self.tool_failure_consecutive_failures >= max_failures
        if not (self.fallback_enabled and threshold_reached and self.fallback_tool_models):
            return None
        logger.debug(
            "_handle_tool_failure: threshold reached, invoking fallback mechanism."
        )
        return self.attempt_fallback(code, logger, agent)

    def attempt_fallback(self, code: str, logger, agent):
        """
        Initiate the fallback process by selecting a fallback model and
        triggering the appropriate fallback method.

        The dispatch (function-calling vs. prompt-based) is decided by the
        'type' of the first fallback model; the chosen method then iterates
        over all fallback models.

        Args:
            code (str): The tool code that triggered the fallback.
            logger: Logger instance for logging messages.
            agent: The agent for which fallback is being executed.

        Returns:
            The response from the fallback invocation, or None when no
            fallback models are available.

        Raises:
            Exception: If every fallback model fails.
        """
        logger.debug(f"_attempt_fallback: initiating fallback for code: {code}")
        if not self.fallback_tool_models:
            # Nothing to fall back to; avoid indexing into an empty list.
            logger.debug("_attempt_fallback: no fallback models available, skipping fallback.")
            return None

        fallback_model = self.fallback_tool_models[0]
        failed_tool_call_name = self._failed_tool_name(code)
        logger.error(
            f"Tool call failed {self.tool_failure_consecutive_failures} times. Attempting fallback to model: {fallback_model['model']} for tool: {failed_tool_call_name}"
        )
        Console().print(
            Panel(
                Markdown(
                    f"**Tool fallback activated**: Switching to fallback model {fallback_model['model']} for tool {failed_tool_call_name}."
                ),
                title="Fallback Notification",
            )
        )
        # Entries without a 'type' (or with a None type) are treated as prompt-based.
        model_type = (fallback_model.get("type") or "prompt").lower()
        if model_type == "fc":
            return self.attempt_fallback_function(code, logger, agent)
        return self.attempt_fallback_prompt(code, logger, agent)

    def reset_fallback_handler(self):
        """
        Reset the fallback handler's internal failure counters and clear the
        record of used fallback models.
        """
        self.tool_failure_consecutive_failures = 0
        self.tool_failure_used_fallbacks.clear()

    def _run_fallback_models(self, code, logger, agent, label):
        """
        Try each fallback model in order until one invokes `code` successfully.

        Args:
            code (str): The tool code to send to the fallback model.
            logger: Logger instance for logging messages.
            agent: Agent providing `tools` and `provider`; its `model` attribute
                is replaced on success.
            label (str): "Prompt-based" or "Function-calling"; used in log and
                error messages only.

        Returns:
            The response of the first fallback model that succeeds.

        Raises:
            Exception: If no fallback model succeeds. The last underlying error,
            if any, is attached as the cause.
        """
        failed_tool_call_name = self._failed_tool_name(code)
        last_error = None
        for fallback_model in self.fallback_tool_models:
            model_name = fallback_model["model"]
            try:
                logger.debug(f"Trying fallback model: {model_name}")
                # Leaderboard entries carry the provider that was validated when
                # they were selected; explicitly configured names do not, so
                # those use the agent's provider.
                provider = fallback_model.get("provider") or agent.provider
                base_model = initialize_llm(provider, model_name)
                # NOTE(review): assumes initialize_llm returns a LangChain chat
                # model, whose bind_tools() returns a new runnable (it does not
                # mutate the model) and whose with_retry() accepts
                # stop_after_attempt rather than retries/delay - confirm against
                # ra_aid.llm. RETRY_FALLBACK_DELAY is therefore not passed here.
                forced_model = base_model.bind_tools(
                    agent.tools, tool_choice=failed_tool_call_name
                ).with_retry(stop_after_attempt=RETRY_FALLBACK_COUNT)
                # The agent receives the fallback model without the forced
                # tool_choice so later turns are not pinned to this one tool.
                agent_model = base_model.with_retry(stop_after_attempt=RETRY_FALLBACK_COUNT)
                response = forced_model.invoke(code)
                self.tool_failure_used_fallbacks.add(model_name)
                agent.model = agent_model
                # NOTE(review): the reset also clears the used-fallback record
                # added just above - confirm whether that is intended.
                self.reset_fallback_handler()
                logger.debug(f"{label} fallback executed successfully with model: {model_name}")
                return response
            except Exception as e:
                # Best effort: remember the failure and move on to the next model.
                last_error = e
                logger.error(f"{label} fallback with model {model_name} failed: {e}")
        raise Exception(f"All {label.lower()} fallback models failed") from last_error

    def attempt_fallback_prompt(self, code: str, logger, agent):
        """
        Attempt a prompt-based fallback by iterating over fallback models and
        invoking the provided code.

        Each fallback model is tried (with retry configured) until one
        successfully executes the code.

        Args:
            code (str): The tool code to invoke via fallback.
            logger: Logger instance for logging messages.
            agent: The agent instance to update with the new model upon success.

        Returns:
            The response from the fallback model invocation.

        Raises:
            Exception: If all prompt-based fallback models fail.
        """
        logger.debug("Attempting prompt-based fallback using fallback models")
        return self._run_fallback_models(code, logger, agent, "Prompt-based")

    def attempt_fallback_function(self, code: str, logger, agent):
        """
        Attempt a function-calling fallback by iterating over fallback models
        and invoking the provided code.

        Each fallback model is tried (with retry configured) until one
        successfully executes the code.

        Args:
            code (str): The tool code to invoke via fallback.
            logger: Logger instance for logging messages.
            agent: The agent instance to update with the new model upon success.

        Returns:
            The response from the fallback model invocation.

        Raises:
            Exception: If all function-calling fallback models fail.
        """
        logger.debug("Attempting function-calling fallback using fallback models")
        return self._run_fallback_models(code, logger, agent, "Function-calling")