feat(config.py): add RETRY_FALLBACK_COUNT and RETRY_FALLBACK_DELAY
to configure retry behavior for fallback models refactor(fallback_handler.py): enhance fallback handling logic to support both prompt-based and function-calling fallbacks with retries fix(fallback_handler.py): update fallback model selection to return dictionaries for better structure and access to model properties
This commit is contained in:
parent
55abf6e5dd
commit
0521b3ff9a
|
|
@ -4,6 +4,8 @@ DEFAULT_RECURSION_LIMIT = 100
|
|||
DEFAULT_MAX_TEST_CMD_RETRIES = 3
|
||||
DEFAULT_MAX_TOOL_FAILURES = 3
|
||||
FALLBACK_TOOL_MODEL_LIMIT = 5
|
||||
RETRY_FALLBACK_COUNT = 3
|
||||
RETRY_FALLBACK_DELAY = 2
|
||||
|
||||
VALID_PROVIDERS = [
|
||||
"anthropic",
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from ra_aid.config import DEFAULT_MAX_TOOL_FAILURES, FALLBACK_TOOL_MODEL_LIMIT
|
||||
from ra_aid.config import DEFAULT_MAX_TOOL_FAILURES, FALLBACK_TOOL_MODEL_LIMIT, RETRY_FALLBACK_COUNT, RETRY_FALLBACK_DELAY
|
||||
from ra_aid.tool_leaderboard import supported_top_tool_models
|
||||
from rich.console import Console
|
||||
from rich.markdown import Markdown
|
||||
|
|
@ -17,9 +17,11 @@ class FallbackHandler:
|
|||
def _load_fallback_tool_models(self, config):
|
||||
fallback_tool_models_config = config.get("fallback_tool_models")
|
||||
if fallback_tool_models_config:
|
||||
return [
|
||||
m.strip() for m in fallback_tool_models_config.split(",") if m.strip()
|
||||
]
|
||||
# Assume comma-separated model names; wrap each in a dict with default type "prompt"
|
||||
models = []
|
||||
for m in [x.strip() for x in fallback_tool_models_config.split(",") if x.strip()]:
|
||||
models.append({"model": m, "type": "prompt"})
|
||||
return models
|
||||
else:
|
||||
console = Console()
|
||||
supported = []
|
||||
|
|
@ -28,13 +30,13 @@ class FallbackHandler:
|
|||
provider = item.get("provider")
|
||||
model_name = item.get("model")
|
||||
if validate_provider_env(provider):
|
||||
supported.append(model_name)
|
||||
supported.append(item)
|
||||
if len(supported) == FALLBACK_TOOL_MODEL_LIMIT:
|
||||
break
|
||||
else:
|
||||
skipped.append(model_name)
|
||||
final_models = supported[:FALLBACK_TOOL_MODEL_LIMIT]
|
||||
message = "Fallback models selected: " + ", ".join(final_models)
|
||||
final_models = supported # list of dicts
|
||||
message = "Fallback models selected: " + ", ".join([m["model"] for m in final_models])
|
||||
if skipped:
|
||||
message += (
|
||||
"\nSkipped top tool calling models due to missing provider ENV API keys: "
|
||||
|
|
@ -64,36 +66,51 @@ class FallbackHandler:
|
|||
|
||||
def attempt_fallback(self, code: str, logger, agent):
|
||||
logger.debug(f"_attempt_fallback: initiating fallback for code: {code}")
|
||||
new_model = self.fallback_tool_models[0]
|
||||
fallback_model = self.fallback_tool_models[0]
|
||||
failed_tool_call_name = code.split("(")[0].strip()
|
||||
logger.error(
|
||||
f"Tool call failed {self.tool_failure_consecutive_failures} times. Attempting fallback to model: {new_model} for tool: {failed_tool_call_name}"
|
||||
)
|
||||
try:
|
||||
logger.debug(f"_attempt_fallback: validating provider {agent.provider}")
|
||||
if not validate_provider_env(agent.provider):
|
||||
logger.error(
|
||||
f"Missing environment configuration for provider {agent.provider}. Cannot fallback."
|
||||
f"Tool call failed {self.tool_failure_consecutive_failures} times. Attempting fallback to model: {fallback_model['model']} for tool: {failed_tool_call_name}"
|
||||
)
|
||||
if fallback_model.get("type", "prompt").lower() == "fc":
|
||||
self.attempt_fallback_function(code, logger, agent)
|
||||
else:
|
||||
logger.debug(
|
||||
f"_attempt_fallback: initializing fallback model {new_model}"
|
||||
)
|
||||
agent.model = initialize_llm(agent.provider, new_model)
|
||||
logger.debug(
|
||||
f"_attempt_fallback: binding tools to new model using tool: {failed_tool_call_name}"
|
||||
)
|
||||
agent.model.bind_tools(agent.tools, tool_choice=failed_tool_call_name)
|
||||
self.tool_failure_used_fallbacks.add(new_model)
|
||||
logger.debug("_attempt_fallback: merging chat history for fallback")
|
||||
merge_chat_history()
|
||||
self.tool_failure_consecutive_failures = 0
|
||||
logger.debug(
|
||||
"_attempt_fallback: fallback successful and tool failure counter reset"
|
||||
)
|
||||
except Exception as switch_e:
|
||||
logger.error(f"Fallback model switching failed: {switch_e}")
|
||||
self.attempt_fallback_prompt(code, logger, agent)
|
||||
|
||||
def reset_fallback_handler(self):
|
||||
self.tool_failure_consecutive_failures = 0
|
||||
self.tool_failure_used_fallbacks.clear()
|
||||
def attempt_fallback_prompt(self, code: str, logger, agent):
|
||||
logger.debug("Attempting prompt-based fallback using fallback models")
|
||||
failed_tool_call_name = code.split("(")[0].strip()
|
||||
for fallback_model in self.fallback_tool_models:
|
||||
try:
|
||||
logger.debug(f"Trying fallback model: {fallback_model['model']}")
|
||||
model = initialize_llm(agent.provider, fallback_model['model']).with_retry(retries=RETRY_FALLBACK_COUNT, delay=RETRY_FALLBACK_DELAY)
|
||||
model.bind_tools(agent.tools, tool_choice=failed_tool_call_name)
|
||||
response = model.invoke(code)
|
||||
self.tool_failure_used_fallbacks.add(fallback_model['model'])
|
||||
agent.model = model
|
||||
self.reset_fallback_handler()
|
||||
logger.debug("Prompt-based fallback executed successfully with model: " + fallback_model['model'])
|
||||
return response
|
||||
except Exception as e:
|
||||
logger.error(f"Prompt-based fallback with model {fallback_model['model']} failed: {e}")
|
||||
raise Exception("All prompt-based fallback models failed")
|
||||
|
||||
def attempt_fallback_function(self, code: str, logger, agent):
|
||||
logger.debug("Attempting function-calling fallback using fallback models")
|
||||
failed_tool_call_name = code.split("(")[0].strip()
|
||||
for fallback_model in self.fallback_tool_models:
|
||||
try:
|
||||
logger.debug(f"Trying fallback model: {fallback_model['model']}")
|
||||
model = initialize_llm(agent.provider, fallback_model['model']).with_retry(retries=RETRY_FALLBACK_COUNT, delay=RETRY_FALLBACK_DELAY)
|
||||
model.bind_tools(agent.tools, tool_choice=failed_tool_call_name)
|
||||
response = model.invoke(code)
|
||||
self.tool_failure_used_fallbacks.add(fallback_model['model'])
|
||||
agent.model = model
|
||||
self.reset_fallback_handler()
|
||||
logger.debug("Function-calling fallback executed successfully with model: " + fallback_model['model'])
|
||||
return response
|
||||
except Exception as e:
|
||||
logger.error(f"Function-calling fallback with model {fallback_model['model']} failed: {e}")
|
||||
raise Exception("All function-calling fallback models failed")
|
||||
|
|
|
|||
Loading…
Reference in New Issue