docs(fallback_handler.py): add detailed docstrings to FallbackHandler methods to improve code documentation and clarity on functionality

This commit is contained in:
Ariel Frischer 2025-02-11 00:44:39 -08:00
parent 3d622911a6
commit d39be05e39
1 changed files with 79 additions and 0 deletions

View File

@ -7,7 +7,21 @@ from ra_aid.llm import initialize_llm, merge_chat_history, validate_provider_env
class FallbackHandler:
"""
FallbackHandler manages fallback logic when tool execution fails.
It loads fallback models from configuration and validated provider settings,
maintains failure counts, and triggers appropriate fallback methods for both
prompt-based and function-calling tool invocations. It also resets internal
counters when a tool call succeeds.
"""
def __init__(self, config):
"""
Initialize the FallbackHandler with the given configuration.
Args:
config (dict): Configuration dictionary that may include fallback settings.
"""
self.config = config
self.fallback_enabled = config.get("fallback_tool_enabled", True)
self.fallback_tool_models = self._load_fallback_tool_models(config)
@ -15,6 +29,19 @@ class FallbackHandler:
self.tool_failure_used_fallbacks = set()
def _load_fallback_tool_models(self, config):
"""
Load and return fallback tool models based on the provided configuration.
If the config specifies 'fallback_tool_models', those are used (assuming comma-separated names).
Otherwise, this method filters the supported_top_tool_models based on provider environment validation,
selecting up to FALLBACK_TOOL_MODEL_LIMIT models.
Args:
config (dict): Configuration dictionary.
Returns:
list of dict: Each dictionary contains keys 'model' and 'type' representing a fallback model.
"""
fallback_tool_models_config = config.get("fallback_tool_models")
if fallback_tool_models_config:
# Assume comma-separated model names; wrap each in a dict with default type "prompt"
@ -46,6 +73,15 @@ class FallbackHandler:
return final_models
def handle_failure(self, code: str, error: Exception, logger, agent):
"""
Handle a tool failure by incrementing the failure counter and triggering fallback if thresholds are exceeded.
Args:
code (str): The code that failed to execute.
error (Exception): The exception raised during execution.
logger: Logger instance for logging.
agent: The agent instance on which fallback may be executed.
"""
logger.debug(
f"_handle_tool_failure: tool failure encountered for code '{code}' with error: {error}"
)
@ -65,6 +101,14 @@ class FallbackHandler:
self.attempt_fallback(code, logger, agent)
def attempt_fallback(self, code: str, logger, agent):
"""
Initiate the fallback process by selecting a fallback model and triggering the appropriate fallback method.
Args:
code (str): The tool code that triggered the fallback.
logger: Logger instance for logging messages.
agent: The agent for which fallback is being executed.
"""
logger.debug(f"_attempt_fallback: initiating fallback for code: {code}")
fallback_model = self.fallback_tool_models[0]
failed_tool_call_name = code.split("(")[0].strip()
@ -78,9 +122,28 @@ class FallbackHandler:
self.attempt_fallback_prompt(code, logger, agent)
def reset_fallback_handler(self):
"""
Reset the fallback handler's internal failure counters and clear the record of used fallback models.
"""
self.tool_failure_consecutive_failures = 0
self.tool_failure_used_fallbacks.clear()
def attempt_fallback_prompt(self, code: str, logger, agent):
"""
Attempt a prompt-based fallback by iterating over fallback models and invoking the provided code.
This method tries each fallback model (with retry logic configured) until one successfully executes the code.
Args:
code (str): The tool code to invoke via fallback.
logger: Logger instance for logging messages.
agent: The agent instance to update with the new model upon success.
Returns:
The response from the fallback model invocation.
Raises:
Exception: If all prompt-based fallback models fail.
"""
logger.debug("Attempting prompt-based fallback using fallback models")
failed_tool_call_name = code.split("(")[0].strip()
for fallback_model in self.fallback_tool_models:
@ -99,6 +162,22 @@ class FallbackHandler:
raise Exception("All prompt-based fallback models failed")
def attempt_fallback_function(self, code: str, logger, agent):
"""
Attempt a function-calling fallback by iterating over fallback models and invoking the provided code.
This method tries each fallback model (with retry logic configured) until one successfully executes the code.
Args:
code (str): The tool code to invoke via fallback.
logger: Logger instance for logging messages.
agent: The agent instance to update with the new model upon success.
Returns:
The response from the fallback model invocation.
Raises:
Exception: If all function-calling fallback models fail.
"""
logger.debug("Attempting function-calling fallback using fallback models")
failed_tool_call_name = code.split("(")[0].strip()
for fallback_model in self.fallback_tool_models: