From 0521b3ff9ae9c7cc214cba354ee30319920b51d5 Mon Sep 17 00:00:00 2001 From: Ariel Frischer Date: Tue, 11 Feb 2025 00:38:15 -0800 Subject: [PATCH] feat(config.py): add RETRY_FALLBACK_COUNT and RETRY_FALLBACK_DELAY to configure retry behavior for fallback models refactor(fallback_handler.py): enhance fallback handling logic to support both prompt-based and function-calling fallbacks with retries fix(fallback_handler.py): update fallback model selection to return dictionaries for better structure and access to model properties --- ra_aid/config.py | 2 + ra_aid/fallback_handler.py | 83 +++++++++++++++++++++++--------------- 2 files changed, 52 insertions(+), 33 deletions(-) diff --git a/ra_aid/config.py b/ra_aid/config.py index 4c9bfea..e85cb12 100644 --- a/ra_aid/config.py +++ b/ra_aid/config.py @@ -4,6 +4,8 @@ DEFAULT_RECURSION_LIMIT = 100 DEFAULT_MAX_TEST_CMD_RETRIES = 3 DEFAULT_MAX_TOOL_FAILURES = 3 FALLBACK_TOOL_MODEL_LIMIT = 5 +RETRY_FALLBACK_COUNT = 3 +RETRY_FALLBACK_DELAY = 2 VALID_PROVIDERS = [ "anthropic", diff --git a/ra_aid/fallback_handler.py b/ra_aid/fallback_handler.py index 0eef7b6..762d697 100644 --- a/ra_aid/fallback_handler.py +++ b/ra_aid/fallback_handler.py @@ -1,4 +1,4 @@ -from ra_aid.config import DEFAULT_MAX_TOOL_FAILURES, FALLBACK_TOOL_MODEL_LIMIT +from ra_aid.config import DEFAULT_MAX_TOOL_FAILURES, FALLBACK_TOOL_MODEL_LIMIT, RETRY_FALLBACK_COUNT, RETRY_FALLBACK_DELAY from ra_aid.tool_leaderboard import supported_top_tool_models from rich.console import Console from rich.markdown import Markdown @@ -17,9 +17,11 @@ class FallbackHandler: def _load_fallback_tool_models(self, config): fallback_tool_models_config = config.get("fallback_tool_models") if fallback_tool_models_config: - return [ - m.strip() for m in fallback_tool_models_config.split(",") if m.strip() - ] + # Assume comma-separated model names; wrap each in a dict with default type "prompt" + models = [] + for m in [x.strip() for x in fallback_tool_models_config.split(",") if x.strip()]: + models.append({"model": m, "type": "prompt"}) + return models else: console = Console() supported = [] @@ -28,13 +30,13 @@ class FallbackHandler: provider = item.get("provider") model_name = item.get("model") if validate_provider_env(provider): - supported.append(model_name) + supported.append(item) if len(supported) == FALLBACK_TOOL_MODEL_LIMIT: break else: skipped.append(model_name) - final_models = supported[:FALLBACK_TOOL_MODEL_LIMIT] - message = "Fallback models selected: " + ", ".join(final_models) + final_models = supported # list of dicts + message = "Fallback models selected: " + ", ".join([m["model"] for m in final_models]) if skipped: message += ( "\nSkipped top tool calling models due to missing provider ENV API keys: " @@ -64,36 +66,51 @@ class FallbackHandler: def attempt_fallback(self, code: str, logger, agent): logger.debug(f"_attempt_fallback: initiating fallback for code: {code}") - new_model = self.fallback_tool_models[0] + fallback_model = self.fallback_tool_models[0] failed_tool_call_name = code.split("(")[0].strip() logger.error( - f"Tool call failed {self.tool_failure_consecutive_failures} times. Attempting fallback to model: {new_model} for tool: {failed_tool_call_name}" + f"Tool call failed {self.tool_failure_consecutive_failures} times. Attempting fallback to model: {fallback_model['model']} for tool: {failed_tool_call_name}" ) - try: - logger.debug(f"_attempt_fallback: validating provider {agent.provider}") - if not validate_provider_env(agent.provider): - logger.error( - f"Missing environment configuration for provider {agent.provider}. Cannot fallback." - ) - else: - logger.debug( - f"_attempt_fallback: initializing fallback model {new_model}" - ) - agent.model = initialize_llm(agent.provider, new_model) - logger.debug( - f"_attempt_fallback: binding tools to new model using tool: {failed_tool_call_name}" - ) - agent.model.bind_tools(agent.tools, tool_choice=failed_tool_call_name) - self.tool_failure_used_fallbacks.add(new_model) - logger.debug("_attempt_fallback: merging chat history for fallback") - merge_chat_history() - self.tool_failure_consecutive_failures = 0 - logger.debug( - "_attempt_fallback: fallback successful and tool failure counter reset" - ) - except Exception as switch_e: - logger.error(f"Fallback model switching failed: {switch_e}") + if fallback_model.get("type", "prompt").lower() == "fc": + self.attempt_fallback_function(code, logger, agent) + else: + self.attempt_fallback_prompt(code, logger, agent) def reset_fallback_handler(self): self.tool_failure_consecutive_failures = 0 self.tool_failure_used_fallbacks.clear() + def attempt_fallback_prompt(self, code: str, logger, agent): + logger.debug("Attempting prompt-based fallback using fallback models") + failed_tool_call_name = code.split("(")[0].strip() + for fallback_model in self.fallback_tool_models: + try: + logger.debug(f"Trying fallback model: {fallback_model['model']}") + model = initialize_llm(agent.provider, fallback_model['model']).with_retry(retries=RETRY_FALLBACK_COUNT, delay=RETRY_FALLBACK_DELAY) + model.bind_tools(agent.tools, tool_choice=failed_tool_call_name) + response = model.invoke(code) + self.tool_failure_used_fallbacks.add(fallback_model['model']) + agent.model = model + self.reset_fallback_handler() + logger.debug("Prompt-based fallback executed successfully with model: " + fallback_model['model']) + return response + except Exception as e: + logger.error(f"Prompt-based fallback with model {fallback_model['model']} failed: {e}") + raise Exception("All prompt-based fallback models failed") + + def attempt_fallback_function(self, code: str, logger, agent): + logger.debug("Attempting function-calling fallback using fallback models") + failed_tool_call_name = code.split("(")[0].strip() + for fallback_model in self.fallback_tool_models: + try: + logger.debug(f"Trying fallback model: {fallback_model['model']}") + model = initialize_llm(agent.provider, fallback_model['model']).with_retry(retries=RETRY_FALLBACK_COUNT, delay=RETRY_FALLBACK_DELAY) + model.bind_tools(agent.tools, tool_choice=failed_tool_call_name) + response = model.invoke(code) + self.tool_failure_used_fallbacks.add(fallback_model['model']) + agent.model = model + self.reset_fallback_handler() + logger.debug("Function-calling fallback executed successfully with model: " + fallback_model['model']) + return response + except Exception as e: + logger.error(f"Function-calling fallback with model {fallback_model['model']} failed: {e}") + raise Exception("All function-calling fallback models failed")