docs(fallback_handler.py): add detailed docstrings to FallbackHandler methods to improve code documentation and clarity on functionality
This commit is contained in:
parent
3d622911a6
commit
d39be05e39
|
|
@ -7,7 +7,21 @@ from ra_aid.llm import initialize_llm, merge_chat_history, validate_provider_env
|
|||
|
||||
|
||||
class FallbackHandler:
|
||||
"""
|
||||
FallbackHandler manages fallback logic when tool execution fails.
|
||||
|
||||
It loads fallback models from configuration and validated provider settings,
|
||||
maintains failure counts, and triggers appropriate fallback methods for both
|
||||
prompt-based and function-calling tool invocations. It also resets internal
|
||||
counters when a tool call succeeds.
|
||||
"""
|
||||
def __init__(self, config):
|
||||
"""
|
||||
Initialize the FallbackHandler with the given configuration.
|
||||
|
||||
Args:
|
||||
config (dict): Configuration dictionary that may include fallback settings.
|
||||
"""
|
||||
self.config = config
|
||||
self.fallback_enabled = config.get("fallback_tool_enabled", True)
|
||||
self.fallback_tool_models = self._load_fallback_tool_models(config)
|
||||
|
|
@ -15,6 +29,19 @@ class FallbackHandler:
|
|||
self.tool_failure_used_fallbacks = set()
|
||||
|
||||
def _load_fallback_tool_models(self, config):
|
||||
"""
|
||||
Load and return fallback tool models based on the provided configuration.
|
||||
|
||||
If the config specifies 'fallback_tool_models', those are used (assuming comma-separated names).
|
||||
Otherwise, this method filters the supported_top_tool_models based on provider environment validation,
|
||||
selecting up to FALLBACK_TOOL_MODEL_LIMIT models.
|
||||
|
||||
Args:
|
||||
config (dict): Configuration dictionary.
|
||||
|
||||
Returns:
|
||||
list of dict: Each dictionary contains keys 'model' and 'type' representing a fallback model.
|
||||
"""
|
||||
fallback_tool_models_config = config.get("fallback_tool_models")
|
||||
if fallback_tool_models_config:
|
||||
# Assume comma-separated model names; wrap each in a dict with default type "prompt"
|
||||
|
|
@ -46,6 +73,15 @@ class FallbackHandler:
|
|||
return final_models
|
||||
|
||||
def handle_failure(self, code: str, error: Exception, logger, agent):
|
||||
"""
|
||||
Handle a tool failure by incrementing the failure counter and triggering fallback if thresholds are exceeded.
|
||||
|
||||
Args:
|
||||
code (str): The code that failed to execute.
|
||||
error (Exception): The exception raised during execution.
|
||||
logger: Logger instance for logging.
|
||||
agent: The agent instance on which fallback may be executed.
|
||||
"""
|
||||
logger.debug(
|
||||
f"_handle_tool_failure: tool failure encountered for code '{code}' with error: {error}"
|
||||
)
|
||||
|
|
@ -65,6 +101,14 @@ class FallbackHandler:
|
|||
self.attempt_fallback(code, logger, agent)
|
||||
|
||||
def attempt_fallback(self, code: str, logger, agent):
|
||||
"""
|
||||
Initiate the fallback process by selecting a fallback model and triggering the appropriate fallback method.
|
||||
|
||||
Args:
|
||||
code (str): The tool code that triggered the fallback.
|
||||
logger: Logger instance for logging messages.
|
||||
agent: The agent for which fallback is being executed.
|
||||
"""
|
||||
logger.debug(f"_attempt_fallback: initiating fallback for code: {code}")
|
||||
fallback_model = self.fallback_tool_models[0]
|
||||
failed_tool_call_name = code.split("(")[0].strip()
|
||||
|
|
@ -78,9 +122,28 @@ class FallbackHandler:
|
|||
self.attempt_fallback_prompt(code, logger, agent)
|
||||
|
||||
def reset_fallback_handler(self):
|
||||
"""
|
||||
Reset the fallback handler's internal failure counters and clear the record of used fallback models.
|
||||
"""
|
||||
self.tool_failure_consecutive_failures = 0
|
||||
self.tool_failure_used_fallbacks.clear()
|
||||
def attempt_fallback_prompt(self, code: str, logger, agent):
|
||||
"""
|
||||
Attempt a prompt-based fallback by iterating over fallback models and invoking the provided code.
|
||||
|
||||
This method tries each fallback model (with retry logic configured) until one successfully executes the code.
|
||||
|
||||
Args:
|
||||
code (str): The tool code to invoke via fallback.
|
||||
logger: Logger instance for logging messages.
|
||||
agent: The agent instance to update with the new model upon success.
|
||||
|
||||
Returns:
|
||||
The response from the fallback model invocation.
|
||||
|
||||
Raises:
|
||||
Exception: If all prompt-based fallback models fail.
|
||||
"""
|
||||
logger.debug("Attempting prompt-based fallback using fallback models")
|
||||
failed_tool_call_name = code.split("(")[0].strip()
|
||||
for fallback_model in self.fallback_tool_models:
|
||||
|
|
@ -99,6 +162,22 @@ class FallbackHandler:
|
|||
raise Exception("All prompt-based fallback models failed")
|
||||
|
||||
def attempt_fallback_function(self, code: str, logger, agent):
|
||||
"""
|
||||
Attempt a function-calling fallback by iterating over fallback models and invoking the provided code.
|
||||
|
||||
This method tries each fallback model (with retry logic configured) until one successfully executes the code.
|
||||
|
||||
Args:
|
||||
code (str): The tool code to invoke via fallback.
|
||||
logger: Logger instance for logging messages.
|
||||
agent: The agent instance to update with the new model upon success.
|
||||
|
||||
Returns:
|
||||
The response from the fallback model invocation.
|
||||
|
||||
Raises:
|
||||
Exception: If all function-calling fallback models fail.
|
||||
"""
|
||||
logger.debug("Attempting function-calling fallback using fallback models")
|
||||
failed_tool_call_name = code.split("(")[0].strip()
|
||||
for fallback_model in self.fallback_tool_models:
|
||||
|
|
|
|||
Loading…
Reference in New Issue