refactor(fallback_handler.py): improve code readability by formatting imports and restructuring for loops

fix(fallback_handler.py): ensure fallback models have a default type of "prompt" and handle exceptions properly during fallback attempts
This commit is contained in:
Ariel Frischer 2025-02-11 01:10:22 -08:00
parent d39be05e39
commit de489584e5
1 changed files with 69 additions and 20 deletions

View File

@ -1,4 +1,8 @@
from ra_aid.config import DEFAULT_MAX_TOOL_FAILURES, FALLBACK_TOOL_MODEL_LIMIT, RETRY_FALLBACK_COUNT, RETRY_FALLBACK_DELAY from ra_aid.config import (
DEFAULT_MAX_TOOL_FAILURES,
FALLBACK_TOOL_MODEL_LIMIT,
RETRY_FALLBACK_COUNT,
)
from ra_aid.tool_leaderboard import supported_top_tool_models from ra_aid.tool_leaderboard import supported_top_tool_models
from rich.console import Console from rich.console import Console
from rich.markdown import Markdown from rich.markdown import Markdown
@ -9,12 +13,13 @@ from ra_aid.llm import initialize_llm, merge_chat_history, validate_provider_env
class FallbackHandler: class FallbackHandler:
""" """
FallbackHandler manages fallback logic when tool execution fails. FallbackHandler manages fallback logic when tool execution fails.
It loads fallback models from configuration and validated provider settings, It loads fallback models from configuration and validated provider settings,
maintains failure counts, and triggers appropriate fallback methods for both maintains failure counts, and triggers appropriate fallback methods for both
prompt-based and function-calling tool invocations. It also resets internal prompt-based and function-calling tool invocations. It also resets internal
counters when a tool call succeeds. counters when a tool call succeeds.
""" """
def __init__(self, config): def __init__(self, config):
""" """
Initialize the FallbackHandler with the given configuration. Initialize the FallbackHandler with the given configuration.
@ -46,7 +51,9 @@ class FallbackHandler:
if fallback_tool_models_config: if fallback_tool_models_config:
# Assume comma-separated model names; wrap each in a dict with default type "prompt" # Assume comma-separated model names; wrap each in a dict with default type "prompt"
models = [] models = []
for m in [x.strip() for x in fallback_tool_models_config.split(",") if x.strip()]: for m in [
x.strip() for x in fallback_tool_models_config.split(",") if x.strip()
]:
models.append({"model": m, "type": "prompt"}) models.append({"model": m, "type": "prompt"})
return models return models
else: else:
@ -62,8 +69,14 @@ class FallbackHandler:
break break
else: else:
skipped.append(model_name) skipped.append(model_name)
final_models = supported # list of dicts final_models = []
message = "Fallback models selected: " + ", ".join([m["model"] for m in final_models]) for item in supported:
if "type" not in item:
item["type"] = "prompt"
final_models.append(item)
message = "Fallback models selected: " + ", ".join(
[m["model"] for m in final_models]
)
if skipped: if skipped:
message += ( message += (
"\nSkipped top tool calling models due to missing provider ENV API keys: " "\nSkipped top tool calling models due to missing provider ENV API keys: "
@ -115,7 +128,14 @@ class FallbackHandler:
logger.error( logger.error(
f"Tool call failed {self.tool_failure_consecutive_failures} times. Attempting fallback to model: {fallback_model['model']} for tool: {failed_tool_call_name}" f"Tool call failed {self.tool_failure_consecutive_failures} times. Attempting fallback to model: {fallback_model['model']} for tool: {failed_tool_call_name}"
) )
Console().print(Panel(Markdown(f"**Tool fallback activated**: Switching to fallback model {fallback_model['model']} for tool {failed_tool_call_name}."), title="Fallback Notification")) Console().print(
Panel(
Markdown(
f"**Tool fallback activated**: Switching to fallback model {fallback_model['model']} for tool {failed_tool_call_name}."
),
title="Fallback Notification",
)
)
if fallback_model.get("type", "prompt").lower() == "fc": if fallback_model.get("type", "prompt").lower() == "fc":
self.attempt_fallback_function(code, logger, agent) self.attempt_fallback_function(code, logger, agent)
else: else:
@ -127,6 +147,7 @@ class FallbackHandler:
""" """
self.tool_failure_consecutive_failures = 0 self.tool_failure_consecutive_failures = 0
self.tool_failure_used_fallbacks.clear() self.tool_failure_used_fallbacks.clear()
def attempt_fallback_prompt(self, code: str, logger, agent): def attempt_fallback_prompt(self, code: str, logger, agent):
""" """
Attempt a prompt-based fallback by iterating over fallback models and invoking the provided code. Attempt a prompt-based fallback by iterating over fallback models and invoking the provided code.
@ -149,16 +170,30 @@ class FallbackHandler:
for fallback_model in self.fallback_tool_models: for fallback_model in self.fallback_tool_models:
try: try:
logger.debug(f"Trying fallback model: {fallback_model['model']}") logger.debug(f"Trying fallback model: {fallback_model['model']}")
model = initialize_llm(agent.provider, fallback_model['model']).with_retry(retries=RETRY_FALLBACK_COUNT, delay=RETRY_FALLBACK_DELAY) simple_model = initialize_llm(
model.bind_tools(agent.tools, tool_choice=failed_tool_call_name) fallback_model["provider"], fallback_model["model"]
response = model.invoke(code) )
self.tool_failure_used_fallbacks.add(fallback_model['model']) binded_model = simple_model.bind_tools(
agent.model = model agent.tools, tool_choice=failed_tool_call_name
)
retry_model = binded_model.with_retry(
stop_after_attempt=RETRY_FALLBACK_COUNT
)
response = retry_model.invoke(code)
self.tool_failure_used_fallbacks.add(fallback_model["model"])
agent.model = retry_model
self.reset_fallback_handler() self.reset_fallback_handler()
logger.debug("Prompt-based fallback executed successfully with model: " + fallback_model['model']) logger.debug(
"Prompt-based fallback executed successfully with model: "
+ fallback_model["model"]
)
return response return response
except Exception as e: except Exception as e:
logger.error(f"Prompt-based fallback with model {fallback_model['model']} failed: {e}") if isinstance(e, KeyboardInterrupt):
raise
logger.error(
f"Prompt-based fallback with model {fallback_model['model']} failed: {e}"
)
raise Exception("All prompt-based fallback models failed") raise Exception("All prompt-based fallback models failed")
def attempt_fallback_function(self, code: str, logger, agent): def attempt_fallback_function(self, code: str, logger, agent):
@ -183,14 +218,28 @@ class FallbackHandler:
for fallback_model in self.fallback_tool_models: for fallback_model in self.fallback_tool_models:
try: try:
logger.debug(f"Trying fallback model: {fallback_model['model']}") logger.debug(f"Trying fallback model: {fallback_model['model']}")
model = initialize_llm(agent.provider, fallback_model['model']).with_retry(retries=RETRY_FALLBACK_COUNT, delay=RETRY_FALLBACK_DELAY) simple_model = initialize_llm(
model.bind_tools(agent.tools, tool_choice=failed_tool_call_name) fallback_model["provider"], fallback_model["model"]
response = model.invoke(code) )
self.tool_failure_used_fallbacks.add(fallback_model['model']) binded_model = simple_model.bind_tools(
agent.model = model agent.tools, tool_choice=failed_tool_call_name
)
retry_model = binded_model.with_retry(
stop_after_attempt=RETRY_FALLBACK_COUNT
)
response = retry_model.invoke(code)
self.tool_failure_used_fallbacks.add(fallback_model["model"])
agent.model = retry_model
self.reset_fallback_handler() self.reset_fallback_handler()
logger.debug("Function-calling fallback executed successfully with model: " + fallback_model['model']) logger.debug(
"Function-calling fallback executed successfully with model: "
+ fallback_model["model"]
)
return response return response
except Exception as e: except Exception as e:
logger.error(f"Function-calling fallback with model {fallback_model['model']} failed: {e}") if isinstance(e, KeyboardInterrupt):
raise
logger.error(
f"Function-calling fallback with model {fallback_model['model']} failed: {e}"
)
raise Exception("All function-calling fallback models failed") raise Exception("All function-calling fallback models failed")