diff --git a/ra_aid/fallback_handler.py b/ra_aid/fallback_handler.py index ef5c73d..ebc3cb6 100644 --- a/ra_aid/fallback_handler.py +++ b/ra_aid/fallback_handler.py @@ -1,4 +1,8 @@ -from ra_aid.config import DEFAULT_MAX_TOOL_FAILURES, FALLBACK_TOOL_MODEL_LIMIT, RETRY_FALLBACK_COUNT, RETRY_FALLBACK_DELAY +from ra_aid.config import ( + DEFAULT_MAX_TOOL_FAILURES, + FALLBACK_TOOL_MODEL_LIMIT, + RETRY_FALLBACK_COUNT, +) from ra_aid.tool_leaderboard import supported_top_tool_models from rich.console import Console from rich.markdown import Markdown @@ -9,12 +13,13 @@ from ra_aid.llm import initialize_llm, merge_chat_history, validate_provider_env class FallbackHandler: """ FallbackHandler manages fallback logic when tool execution fails. - + It loads fallback models from configuration and validated provider settings, maintains failure counts, and triggers appropriate fallback methods for both prompt-based and function-calling tool invocations. It also resets internal counters when a tool call succeeds. """ + def __init__(self, config): """ Initialize the FallbackHandler with the given configuration. @@ -46,7 +51,9 @@ class FallbackHandler: if fallback_tool_models_config: # Assume comma-separated model names; wrap each in a dict with default type "prompt" models = [] - for m in [x.strip() for x in fallback_tool_models_config.split(",") if x.strip()]: + for m in [ + x.strip() for x in fallback_tool_models_config.split(",") if x.strip() + ]: models.append({"model": m, "type": "prompt"}) return models else: @@ -62,8 +69,14 @@ class FallbackHandler: break else: skipped.append(model_name) - final_models = supported # list of dicts - message = "Fallback models selected: " + ", ".join([m["model"] for m in final_models]) + final_models = [] + for item in supported: + if "type" not in item: + item["type"] = "prompt" + final_models.append(item) + message = "Fallback models selected: " + ", ".join( + [m["model"] for m in final_models] + ) if skipped: message += ( "\nSkipped top tool calling models due to missing provider ENV API keys: " @@ -115,7 +128,14 @@ class FallbackHandler: logger.error( f"Tool call failed {self.tool_failure_consecutive_failures} times. Attempting fallback to model: {fallback_model['model']} for tool: {failed_tool_call_name}" ) - Console().print(Panel(Markdown(f"**Tool fallback activated**: Switching to fallback model {fallback_model['model']} for tool {failed_tool_call_name}."), title="Fallback Notification")) + Console().print( + Panel( + Markdown( + f"**Tool fallback activated**: Switching to fallback model {fallback_model['model']} for tool {failed_tool_call_name}." + ), + title="Fallback Notification", + ) + ) if fallback_model.get("type", "prompt").lower() == "fc": self.attempt_fallback_function(code, logger, agent) else: @@ -127,6 +147,7 @@ class FallbackHandler: """ self.tool_failure_consecutive_failures = 0 self.tool_failure_used_fallbacks.clear() + def attempt_fallback_prompt(self, code: str, logger, agent): """ Attempt a prompt-based fallback by iterating over fallback models and invoking the provided code. @@ -149,16 +170,30 @@ class FallbackHandler: for fallback_model in self.fallback_tool_models: try: logger.debug(f"Trying fallback model: {fallback_model['model']}") - model = initialize_llm(agent.provider, fallback_model['model']).with_retry(retries=RETRY_FALLBACK_COUNT, delay=RETRY_FALLBACK_DELAY) - model.bind_tools(agent.tools, tool_choice=failed_tool_call_name) - response = model.invoke(code) - self.tool_failure_used_fallbacks.add(fallback_model['model']) - agent.model = model + simple_model = initialize_llm( + fallback_model["provider"], fallback_model["model"] + ) + binded_model = simple_model.bind_tools( + agent.tools, tool_choice=failed_tool_call_name + ) + retry_model = binded_model.with_retry( + stop_after_attempt=RETRY_FALLBACK_COUNT + ) + response = retry_model.invoke(code) + self.tool_failure_used_fallbacks.add(fallback_model["model"]) + agent.model = retry_model self.reset_fallback_handler() - logger.debug("Prompt-based fallback executed successfully with model: " + fallback_model['model']) + logger.debug( + "Prompt-based fallback executed successfully with model: " + + fallback_model["model"] + ) return response except Exception as e: - logger.error(f"Prompt-based fallback with model {fallback_model['model']} failed: {e}") + if isinstance(e, KeyboardInterrupt): + raise + logger.error( + f"Prompt-based fallback with model {fallback_model['model']} failed: {e}" + ) raise Exception("All prompt-based fallback models failed") def attempt_fallback_function(self, code: str, logger, agent): @@ -183,14 +218,28 @@ class FallbackHandler: for fallback_model in self.fallback_tool_models: try: logger.debug(f"Trying fallback model: {fallback_model['model']}") - model = initialize_llm(agent.provider, fallback_model['model']).with_retry(retries=RETRY_FALLBACK_COUNT, delay=RETRY_FALLBACK_DELAY) - model.bind_tools(agent.tools, tool_choice=failed_tool_call_name) - response = model.invoke(code) - self.tool_failure_used_fallbacks.add(fallback_model['model']) - agent.model = model + simple_model = initialize_llm( + fallback_model["provider"], fallback_model["model"] + ) + binded_model = simple_model.bind_tools( + agent.tools, tool_choice=failed_tool_call_name + ) + retry_model = binded_model.with_retry( + stop_after_attempt=RETRY_FALLBACK_COUNT + ) + response = retry_model.invoke(code) + self.tool_failure_used_fallbacks.add(fallback_model["model"]) + agent.model = retry_model self.reset_fallback_handler() - logger.debug("Function-calling fallback executed successfully with model: " + fallback_model['model']) + logger.debug( + "Function-calling fallback executed successfully with model: " + + fallback_model["model"] + ) return response except Exception as e: - logger.error(f"Function-calling fallback with model {fallback_model['model']} failed: {e}") + if isinstance(e, KeyboardInterrupt): + raise + logger.error( + f"Function-calling fallback with model {fallback_model['model']} failed: {e}" + ) raise Exception("All function-calling fallback models failed")