refactor(fallback_handler.py): improve code readability by formatting imports and restructuring for loops
fix(fallback_handler.py): ensure fallback models have a default type of "prompt" and handle exceptions properly during fallback attempts
This commit is contained in:
parent
d39be05e39
commit
de489584e5
|
|
@ -1,4 +1,8 @@
|
||||||
from ra_aid.config import DEFAULT_MAX_TOOL_FAILURES, FALLBACK_TOOL_MODEL_LIMIT, RETRY_FALLBACK_COUNT, RETRY_FALLBACK_DELAY
|
from ra_aid.config import (
|
||||||
|
DEFAULT_MAX_TOOL_FAILURES,
|
||||||
|
FALLBACK_TOOL_MODEL_LIMIT,
|
||||||
|
RETRY_FALLBACK_COUNT,
|
||||||
|
)
|
||||||
from ra_aid.tool_leaderboard import supported_top_tool_models
|
from ra_aid.tool_leaderboard import supported_top_tool_models
|
||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
from rich.markdown import Markdown
|
from rich.markdown import Markdown
|
||||||
|
|
@ -15,6 +19,7 @@ class FallbackHandler:
|
||||||
prompt-based and function-calling tool invocations. It also resets internal
|
prompt-based and function-calling tool invocations. It also resets internal
|
||||||
counters when a tool call succeeds.
|
counters when a tool call succeeds.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
"""
|
"""
|
||||||
Initialize the FallbackHandler with the given configuration.
|
Initialize the FallbackHandler with the given configuration.
|
||||||
|
|
@ -46,7 +51,9 @@ class FallbackHandler:
|
||||||
if fallback_tool_models_config:
|
if fallback_tool_models_config:
|
||||||
# Assume comma-separated model names; wrap each in a dict with default type "prompt"
|
# Assume comma-separated model names; wrap each in a dict with default type "prompt"
|
||||||
models = []
|
models = []
|
||||||
for m in [x.strip() for x in fallback_tool_models_config.split(",") if x.strip()]:
|
for m in [
|
||||||
|
x.strip() for x in fallback_tool_models_config.split(",") if x.strip()
|
||||||
|
]:
|
||||||
models.append({"model": m, "type": "prompt"})
|
models.append({"model": m, "type": "prompt"})
|
||||||
return models
|
return models
|
||||||
else:
|
else:
|
||||||
|
|
@ -62,8 +69,14 @@ class FallbackHandler:
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
skipped.append(model_name)
|
skipped.append(model_name)
|
||||||
final_models = supported # list of dicts
|
final_models = []
|
||||||
message = "Fallback models selected: " + ", ".join([m["model"] for m in final_models])
|
for item in supported:
|
||||||
|
if "type" not in item:
|
||||||
|
item["type"] = "prompt"
|
||||||
|
final_models.append(item)
|
||||||
|
message = "Fallback models selected: " + ", ".join(
|
||||||
|
[m["model"] for m in final_models]
|
||||||
|
)
|
||||||
if skipped:
|
if skipped:
|
||||||
message += (
|
message += (
|
||||||
"\nSkipped top tool calling models due to missing provider ENV API keys: "
|
"\nSkipped top tool calling models due to missing provider ENV API keys: "
|
||||||
|
|
@ -115,7 +128,14 @@ class FallbackHandler:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"Tool call failed {self.tool_failure_consecutive_failures} times. Attempting fallback to model: {fallback_model['model']} for tool: {failed_tool_call_name}"
|
f"Tool call failed {self.tool_failure_consecutive_failures} times. Attempting fallback to model: {fallback_model['model']} for tool: {failed_tool_call_name}"
|
||||||
)
|
)
|
||||||
Console().print(Panel(Markdown(f"**Tool fallback activated**: Switching to fallback model {fallback_model['model']} for tool {failed_tool_call_name}."), title="Fallback Notification"))
|
Console().print(
|
||||||
|
Panel(
|
||||||
|
Markdown(
|
||||||
|
f"**Tool fallback activated**: Switching to fallback model {fallback_model['model']} for tool {failed_tool_call_name}."
|
||||||
|
),
|
||||||
|
title="Fallback Notification",
|
||||||
|
)
|
||||||
|
)
|
||||||
if fallback_model.get("type", "prompt").lower() == "fc":
|
if fallback_model.get("type", "prompt").lower() == "fc":
|
||||||
self.attempt_fallback_function(code, logger, agent)
|
self.attempt_fallback_function(code, logger, agent)
|
||||||
else:
|
else:
|
||||||
|
|
@ -127,6 +147,7 @@ class FallbackHandler:
|
||||||
"""
|
"""
|
||||||
self.tool_failure_consecutive_failures = 0
|
self.tool_failure_consecutive_failures = 0
|
||||||
self.tool_failure_used_fallbacks.clear()
|
self.tool_failure_used_fallbacks.clear()
|
||||||
|
|
||||||
def attempt_fallback_prompt(self, code: str, logger, agent):
|
def attempt_fallback_prompt(self, code: str, logger, agent):
|
||||||
"""
|
"""
|
||||||
Attempt a prompt-based fallback by iterating over fallback models and invoking the provided code.
|
Attempt a prompt-based fallback by iterating over fallback models and invoking the provided code.
|
||||||
|
|
@ -149,16 +170,30 @@ class FallbackHandler:
|
||||||
for fallback_model in self.fallback_tool_models:
|
for fallback_model in self.fallback_tool_models:
|
||||||
try:
|
try:
|
||||||
logger.debug(f"Trying fallback model: {fallback_model['model']}")
|
logger.debug(f"Trying fallback model: {fallback_model['model']}")
|
||||||
model = initialize_llm(agent.provider, fallback_model['model']).with_retry(retries=RETRY_FALLBACK_COUNT, delay=RETRY_FALLBACK_DELAY)
|
simple_model = initialize_llm(
|
||||||
model.bind_tools(agent.tools, tool_choice=failed_tool_call_name)
|
fallback_model["provider"], fallback_model["model"]
|
||||||
response = model.invoke(code)
|
)
|
||||||
self.tool_failure_used_fallbacks.add(fallback_model['model'])
|
binded_model = simple_model.bind_tools(
|
||||||
agent.model = model
|
agent.tools, tool_choice=failed_tool_call_name
|
||||||
|
)
|
||||||
|
retry_model = binded_model.with_retry(
|
||||||
|
stop_after_attempt=RETRY_FALLBACK_COUNT
|
||||||
|
)
|
||||||
|
response = retry_model.invoke(code)
|
||||||
|
self.tool_failure_used_fallbacks.add(fallback_model["model"])
|
||||||
|
agent.model = retry_model
|
||||||
self.reset_fallback_handler()
|
self.reset_fallback_handler()
|
||||||
logger.debug("Prompt-based fallback executed successfully with model: " + fallback_model['model'])
|
logger.debug(
|
||||||
|
"Prompt-based fallback executed successfully with model: "
|
||||||
|
+ fallback_model["model"]
|
||||||
|
)
|
||||||
return response
|
return response
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Prompt-based fallback with model {fallback_model['model']} failed: {e}")
|
if isinstance(e, KeyboardInterrupt):
|
||||||
|
raise
|
||||||
|
logger.error(
|
||||||
|
f"Prompt-based fallback with model {fallback_model['model']} failed: {e}"
|
||||||
|
)
|
||||||
raise Exception("All prompt-based fallback models failed")
|
raise Exception("All prompt-based fallback models failed")
|
||||||
|
|
||||||
def attempt_fallback_function(self, code: str, logger, agent):
|
def attempt_fallback_function(self, code: str, logger, agent):
|
||||||
|
|
@ -183,14 +218,28 @@ class FallbackHandler:
|
||||||
for fallback_model in self.fallback_tool_models:
|
for fallback_model in self.fallback_tool_models:
|
||||||
try:
|
try:
|
||||||
logger.debug(f"Trying fallback model: {fallback_model['model']}")
|
logger.debug(f"Trying fallback model: {fallback_model['model']}")
|
||||||
model = initialize_llm(agent.provider, fallback_model['model']).with_retry(retries=RETRY_FALLBACK_COUNT, delay=RETRY_FALLBACK_DELAY)
|
simple_model = initialize_llm(
|
||||||
model.bind_tools(agent.tools, tool_choice=failed_tool_call_name)
|
fallback_model["provider"], fallback_model["model"]
|
||||||
response = model.invoke(code)
|
)
|
||||||
self.tool_failure_used_fallbacks.add(fallback_model['model'])
|
binded_model = simple_model.bind_tools(
|
||||||
agent.model = model
|
agent.tools, tool_choice=failed_tool_call_name
|
||||||
|
)
|
||||||
|
retry_model = binded_model.with_retry(
|
||||||
|
stop_after_attempt=RETRY_FALLBACK_COUNT
|
||||||
|
)
|
||||||
|
response = retry_model.invoke(code)
|
||||||
|
self.tool_failure_used_fallbacks.add(fallback_model["model"])
|
||||||
|
agent.model = retry_model
|
||||||
self.reset_fallback_handler()
|
self.reset_fallback_handler()
|
||||||
logger.debug("Function-calling fallback executed successfully with model: " + fallback_model['model'])
|
logger.debug(
|
||||||
|
"Function-calling fallback executed successfully with model: "
|
||||||
|
+ fallback_model["model"]
|
||||||
|
)
|
||||||
return response
|
return response
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Function-calling fallback with model {fallback_model['model']} failed: {e}")
|
if isinstance(e, KeyboardInterrupt):
|
||||||
|
raise
|
||||||
|
logger.error(
|
||||||
|
f"Function-calling fallback with model {fallback_model['model']} failed: {e}"
|
||||||
|
)
|
||||||
raise Exception("All function-calling fallback models failed")
|
raise Exception("All function-calling fallback models failed")
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue