refactor(fallback_handler.py): improve code readability by formatting imports and restructuring for loops

fix(fallback_handler.py): ensure fallback models have a default type of "prompt" and handle exceptions properly during fallback attempts
This commit is contained in:
Ariel Frischer 2025-02-11 01:10:22 -08:00
parent d39be05e39
commit de489584e5
1 changed files with 69 additions and 20 deletions

View File

@ -1,4 +1,8 @@
from ra_aid.config import DEFAULT_MAX_TOOL_FAILURES, FALLBACK_TOOL_MODEL_LIMIT, RETRY_FALLBACK_COUNT, RETRY_FALLBACK_DELAY from ra_aid.config import (
DEFAULT_MAX_TOOL_FAILURES,
FALLBACK_TOOL_MODEL_LIMIT,
RETRY_FALLBACK_COUNT,
)
from ra_aid.tool_leaderboard import supported_top_tool_models from ra_aid.tool_leaderboard import supported_top_tool_models
from rich.console import Console from rich.console import Console
from rich.markdown import Markdown from rich.markdown import Markdown
@ -15,6 +19,7 @@ class FallbackHandler:
prompt-based and function-calling tool invocations. It also resets internal prompt-based and function-calling tool invocations. It also resets internal
counters when a tool call succeeds. counters when a tool call succeeds.
""" """
def __init__(self, config): def __init__(self, config):
""" """
Initialize the FallbackHandler with the given configuration. Initialize the FallbackHandler with the given configuration.
@ -46,7 +51,9 @@ class FallbackHandler:
if fallback_tool_models_config: if fallback_tool_models_config:
# Assume comma-separated model names; wrap each in a dict with default type "prompt" # Assume comma-separated model names; wrap each in a dict with default type "prompt"
models = [] models = []
for m in [x.strip() for x in fallback_tool_models_config.split(",") if x.strip()]: for m in [
x.strip() for x in fallback_tool_models_config.split(",") if x.strip()
]:
models.append({"model": m, "type": "prompt"}) models.append({"model": m, "type": "prompt"})
return models return models
else: else:
@ -62,8 +69,14 @@ class FallbackHandler:
break break
else: else:
skipped.append(model_name) skipped.append(model_name)
final_models = supported # list of dicts final_models = []
message = "Fallback models selected: " + ", ".join([m["model"] for m in final_models]) for item in supported:
if "type" not in item:
item["type"] = "prompt"
final_models.append(item)
message = "Fallback models selected: " + ", ".join(
[m["model"] for m in final_models]
)
if skipped: if skipped:
message += ( message += (
"\nSkipped top tool calling models due to missing provider ENV API keys: " "\nSkipped top tool calling models due to missing provider ENV API keys: "
@ -115,7 +128,14 @@ class FallbackHandler:
logger.error( logger.error(
f"Tool call failed {self.tool_failure_consecutive_failures} times. Attempting fallback to model: {fallback_model['model']} for tool: {failed_tool_call_name}" f"Tool call failed {self.tool_failure_consecutive_failures} times. Attempting fallback to model: {fallback_model['model']} for tool: {failed_tool_call_name}"
) )
Console().print(Panel(Markdown(f"**Tool fallback activated**: Switching to fallback model {fallback_model['model']} for tool {failed_tool_call_name}."), title="Fallback Notification")) Console().print(
Panel(
Markdown(
f"**Tool fallback activated**: Switching to fallback model {fallback_model['model']} for tool {failed_tool_call_name}."
),
title="Fallback Notification",
)
)
if fallback_model.get("type", "prompt").lower() == "fc": if fallback_model.get("type", "prompt").lower() == "fc":
self.attempt_fallback_function(code, logger, agent) self.attempt_fallback_function(code, logger, agent)
else: else:
@ -127,6 +147,7 @@ class FallbackHandler:
""" """
self.tool_failure_consecutive_failures = 0 self.tool_failure_consecutive_failures = 0
self.tool_failure_used_fallbacks.clear() self.tool_failure_used_fallbacks.clear()
def attempt_fallback_prompt(self, code: str, logger, agent): def attempt_fallback_prompt(self, code: str, logger, agent):
""" """
Attempt a prompt-based fallback by iterating over fallback models and invoking the provided code. Attempt a prompt-based fallback by iterating over fallback models and invoking the provided code.
@ -149,16 +170,30 @@ class FallbackHandler:
for fallback_model in self.fallback_tool_models: for fallback_model in self.fallback_tool_models:
try: try:
logger.debug(f"Trying fallback model: {fallback_model['model']}") logger.debug(f"Trying fallback model: {fallback_model['model']}")
model = initialize_llm(agent.provider, fallback_model['model']).with_retry(retries=RETRY_FALLBACK_COUNT, delay=RETRY_FALLBACK_DELAY) simple_model = initialize_llm(
model.bind_tools(agent.tools, tool_choice=failed_tool_call_name) fallback_model["provider"], fallback_model["model"]
response = model.invoke(code) )
self.tool_failure_used_fallbacks.add(fallback_model['model']) binded_model = simple_model.bind_tools(
agent.model = model agent.tools, tool_choice=failed_tool_call_name
)
retry_model = binded_model.with_retry(
stop_after_attempt=RETRY_FALLBACK_COUNT
)
response = retry_model.invoke(code)
self.tool_failure_used_fallbacks.add(fallback_model["model"])
agent.model = retry_model
self.reset_fallback_handler() self.reset_fallback_handler()
logger.debug("Prompt-based fallback executed successfully with model: " + fallback_model['model']) logger.debug(
"Prompt-based fallback executed successfully with model: "
+ fallback_model["model"]
)
return response return response
except Exception as e: except Exception as e:
logger.error(f"Prompt-based fallback with model {fallback_model['model']} failed: {e}") if isinstance(e, KeyboardInterrupt):
raise
logger.error(
f"Prompt-based fallback with model {fallback_model['model']} failed: {e}"
)
raise Exception("All prompt-based fallback models failed") raise Exception("All prompt-based fallback models failed")
def attempt_fallback_function(self, code: str, logger, agent): def attempt_fallback_function(self, code: str, logger, agent):
@ -183,14 +218,28 @@ class FallbackHandler:
for fallback_model in self.fallback_tool_models: for fallback_model in self.fallback_tool_models:
try: try:
logger.debug(f"Trying fallback model: {fallback_model['model']}") logger.debug(f"Trying fallback model: {fallback_model['model']}")
model = initialize_llm(agent.provider, fallback_model['model']).with_retry(retries=RETRY_FALLBACK_COUNT, delay=RETRY_FALLBACK_DELAY) simple_model = initialize_llm(
model.bind_tools(agent.tools, tool_choice=failed_tool_call_name) fallback_model["provider"], fallback_model["model"]
response = model.invoke(code) )
self.tool_failure_used_fallbacks.add(fallback_model['model']) binded_model = simple_model.bind_tools(
agent.model = model agent.tools, tool_choice=failed_tool_call_name
)
retry_model = binded_model.with_retry(
stop_after_attempt=RETRY_FALLBACK_COUNT
)
response = retry_model.invoke(code)
self.tool_failure_used_fallbacks.add(fallback_model["model"])
agent.model = retry_model
self.reset_fallback_handler() self.reset_fallback_handler()
logger.debug("Function-calling fallback executed successfully with model: " + fallback_model['model']) logger.debug(
"Function-calling fallback executed successfully with model: "
+ fallback_model["model"]
)
return response return response
except Exception as e: except Exception as e:
logger.error(f"Function-calling fallback with model {fallback_model['model']} failed: {e}") if isinstance(e, KeyboardInterrupt):
raise
logger.error(
f"Function-calling fallback with model {fallback_model['model']} failed: {e}"
)
raise Exception("All function-calling fallback models failed") raise Exception("All function-calling fallback models failed")