feat(fallback_handler): implement FallbackHandler class to manage tool failures and fallback logic

refactor(ciayn_agent): integrate FallbackHandler into CiaynAgent for improved failure handling
fix(agent_utils): add missing newline for better readability in run_agent_with_retry function
test(fallback_handler): add unit tests for FallbackHandler to ensure correct failure handling and fallback logic
This commit is contained in:
Ariel Frischer 2025-02-10 23:37:15 -08:00
parent d8ee4e04f4
commit 55abf6e5dd
6 changed files with 164 additions and 124 deletions

View File

@ -806,6 +806,7 @@ def run_agent_with_retry(agent, prompt: str, config: dict) -> Optional[str]:
logger.debug("Agent output: %s", chunk)
check_interrupt()
print_agent_output(chunk)
if _global_memory["plan_completed"]:
_global_memory["plan_completed"] = False
_global_memory["task_completed"] = False

View File

@ -4,7 +4,7 @@ from typing import Any, Dict, Generator, List, Optional, Union
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
from ra_aid.config import DEFAULT_MAX_TOOL_FAILURES
from ra_aid.fallback_handler import FallbackHandler
from ra_aid.exceptions import ToolExecutionError
from ra_aid.logging_config import get_logger
from ra_aid.models_params import DEFAULT_TOKEN_LIMIT
@ -90,8 +90,7 @@ class CiaynAgent:
config = {}
self.config = config
self.provider = config.get("provider", "openai")
self.fallback_enabled = config.get("fallback_tool_enabled", True)
self.fallback_tool_models = self._load_fallback_tool_models(config)
self.fallback_handler = FallbackHandler(config)
self.model = model
self.tools = tools
@ -100,18 +99,8 @@ class CiaynAgent:
self.available_functions = []
for t in tools:
self.available_functions.append(get_function_info(t.func))
self.tool_failure_consecutive_failures = 0
self.tool_failure_current_provider = None
self.tool_failure_current_model = None
self.tool_failure_used_fallbacks = set()
def _load_fallback_tool_models(self, config: dict) -> list:
fallback_tool_models_config = config.get("fallback_tool_models")
if fallback_tool_models_config:
return [m.strip() for m in fallback_tool_models_config.split(",") if m.strip()]
else:
from ra_aid.tool_leaderboard import supported_top_tool_models
return [item["model"] for item in supported_top_tool_models[:5]]
def _build_prompt(self, last_result: Optional[str] = None) -> str:
"""Build the prompt for the agent including available tools and context."""
@ -262,7 +251,7 @@ Output **ONLY THE CODE** and **NO MARKDOWN BACKTICKS**"""
logger.debug(
f"_execute_tool: tool executed successfully with result: {result}"
)
self.tool_failure_consecutive_failures = 0
self.fallback_handler.reset_fallback_handler()
return result
except Exception as e:
logger.debug(f"_execute_tool: exception caught: {e}")
@ -275,61 +264,7 @@ Output **ONLY THE CODE** and **NO MARKDOWN BACKTICKS**"""
)
def _handle_tool_failure(self, code: str, error: Exception) -> None:
logger.debug(
f"_handle_tool_failure: tool failure encountered for code '{code}' with error: {error}"
)
self.tool_failure_consecutive_failures += 1
max_failures = self.config.get("max_tool_failures", DEFAULT_MAX_TOOL_FAILURES)
logger.debug(
f"_handle_tool_failure: failure count {self.tool_failure_consecutive_failures}, max_failures {max_failures}"
)
if (
self.fallback_enabled
and self.tool_failure_consecutive_failures >= max_failures
and self.fallback_tool_models
):
logger.debug(
"_handle_tool_failure: threshold reached, invoking fallback mechanism."
)
self._attempt_fallback(code)
def _attempt_fallback(self, code: str) -> None:
logger.debug(f"_attempt_fallback: initiating fallback for code: {code}")
new_model = self.fallback_tool_models[0]
failed_tool_call_name = code.split("(")[0].strip()
logger.error(
f"Tool call failed {self.tool_failure_consecutive_failures} times. Attempting fallback to model: {new_model} for tool: {failed_tool_call_name}"
)
try:
from ra_aid.llm import (
initialize_llm,
merge_chat_history,
validate_provider_env,
)
logger.debug(f"_attempt_fallback: validating provider {self.provider}")
if not validate_provider_env(self.provider):
logger.error(
f"Missing environment configuration for provider {self.provider}. Cannot fallback."
)
else:
logger.debug(
f"_attempt_fallback: initializing fallback model {new_model}"
)
self.model = initialize_llm(self.provider, new_model)
logger.debug(
f"_attempt_fallback: binding tools to new model using tool: {failed_tool_call_name}"
)
self.model.bind_tools(self.tools, tool_choice=failed_tool_call_name)
self.tool_failure_used_fallbacks.add(new_model)
logger.debug("_attempt_fallback: merging chat history for fallback")
merge_chat_history()
self.tool_failure_consecutive_failures = 0
logger.debug(
"_attempt_fallback: fallback successful and tool failure counter reset"
)
except Exception as switch_e:
logger.error(f"Fallback model switching failed: {switch_e}")
self.fallback_handler.handle_failure(code, error, logger, self)
def _create_agent_chunk(self, content: str) -> Dict[str, Any]:
"""Create an agent chunk in the format expected by print_agent_output."""

View File

@ -3,6 +3,7 @@
DEFAULT_RECURSION_LIMIT = 100
DEFAULT_MAX_TEST_CMD_RETRIES = 3
DEFAULT_MAX_TOOL_FAILURES = 3
FALLBACK_TOOL_MODEL_LIMIT = 5
VALID_PROVIDERS = [
"anthropic",

View File

@ -0,0 +1,99 @@
from ra_aid.config import DEFAULT_MAX_TOOL_FAILURES, FALLBACK_TOOL_MODEL_LIMIT
from ra_aid.tool_leaderboard import supported_top_tool_models
from rich.console import Console
from rich.markdown import Markdown
from rich.panel import Panel
from ra_aid.llm import initialize_llm, merge_chat_history, validate_provider_env
class FallbackHandler:
def __init__(self, config):
self.config = config
self.fallback_enabled = config.get("fallback_tool_enabled", True)
self.fallback_tool_models = self._load_fallback_tool_models(config)
self.tool_failure_consecutive_failures = 0
self.tool_failure_used_fallbacks = set()
def _load_fallback_tool_models(self, config):
fallback_tool_models_config = config.get("fallback_tool_models")
if fallback_tool_models_config:
return [
m.strip() for m in fallback_tool_models_config.split(",") if m.strip()
]
else:
console = Console()
supported = []
skipped = []
for item in supported_top_tool_models:
provider = item.get("provider")
model_name = item.get("model")
if validate_provider_env(provider):
supported.append(model_name)
if len(supported) == FALLBACK_TOOL_MODEL_LIMIT:
break
else:
skipped.append(model_name)
final_models = supported[:FALLBACK_TOOL_MODEL_LIMIT]
message = "Fallback models selected: " + ", ".join(final_models)
if skipped:
message += (
"\nSkipped top tool calling models due to missing provider ENV API keys: "
+ ", ".join(skipped)
)
console.print(Panel(Markdown(message), title="Fallback Models"))
return final_models
def handle_failure(self, code: str, error: Exception, logger, agent):
logger.debug(
f"_handle_tool_failure: tool failure encountered for code '{code}' with error: {error}"
)
self.tool_failure_consecutive_failures += 1
max_failures = self.config.get("max_tool_failures", DEFAULT_MAX_TOOL_FAILURES)
logger.debug(
f"_handle_tool_failure: failure count {self.tool_failure_consecutive_failures}, max_failures {max_failures}"
)
if (
self.fallback_enabled
and self.tool_failure_consecutive_failures >= max_failures
and self.fallback_tool_models
):
logger.debug(
"_handle_tool_failure: threshold reached, invoking fallback mechanism."
)
self.attempt_fallback(code, logger, agent)
def attempt_fallback(self, code: str, logger, agent):
logger.debug(f"_attempt_fallback: initiating fallback for code: {code}")
new_model = self.fallback_tool_models[0]
failed_tool_call_name = code.split("(")[0].strip()
logger.error(
f"Tool call failed {self.tool_failure_consecutive_failures} times. Attempting fallback to model: {new_model} for tool: {failed_tool_call_name}"
)
try:
logger.debug(f"_attempt_fallback: validating provider {agent.provider}")
if not validate_provider_env(agent.provider):
logger.error(
f"Missing environment configuration for provider {agent.provider}. Cannot fallback."
)
else:
logger.debug(
f"_attempt_fallback: initializing fallback model {new_model}"
)
agent.model = initialize_llm(agent.provider, new_model)
logger.debug(
f"_attempt_fallback: binding tools to new model using tool: {failed_tool_call_name}"
)
agent.model.bind_tools(agent.tools, tool_choice=failed_tool_call_name)
self.tool_failure_used_fallbacks.add(new_model)
logger.debug("_attempt_fallback: merging chat history for fallback")
merge_chat_history()
self.tool_failure_consecutive_failures = 0
logger.debug(
"_attempt_fallback: fallback successful and tool failure counter reset"
)
except Exception as switch_e:
logger.error(f"Fallback model switching failed: {switch_e}")
def reset_fallback_handler(self):
self.tool_failure_consecutive_failures = 0
self.tool_failure_used_fallbacks.clear()

View File

@ -264,61 +264,7 @@ class TestFunctionCallValidation:
class TestCiaynAgentNewMethods(unittest.TestCase):
def setUp(self):
# Create a dummy tool that always fails for testing fallback
def always_fail():
raise Exception("Failure for fallback test")
self.always_fail_tool = DummyTool(always_fail)
# Create a dummy model that does minimal work for fallback tests
self.dummy_model = DummyModel()
# Initialize CiaynAgent with configuration to trigger fallback quickly
self.agent = CiaynAgent(
self.dummy_model,
[self.always_fail_tool],
config={
"max_tool_failures": 2,
"fallback_tool_models": "dummy-fallback-model",
},
)
def test_handle_tool_failure_increments_counter(self):
initial_failures = self.agent.tool_failure_consecutive_failures
self.agent._handle_tool_failure("dummy_call()", Exception("Test error"))
self.assertEqual(
self.agent.tool_failure_consecutive_failures, initial_failures + 1
)
def test_attempt_fallback_invokes_fallback_logic(self):
# Monkey-patch initialize_llm, merge_chat_history, and validate_provider_env
# to simulate fallback switching without external dependencies.
def dummy_initialize_llm(provider, model_name, temperature=None):
return self.dummy_model
def dummy_merge_chat_history():
return ["merged"]
def dummy_validate_provider_env(provider):
return True
import ra_aid.llm as llm
original_initialize = llm.initialize_llm
original_merge = llm.merge_chat_history
original_validate = llm.validate_provider_env
llm.initialize_llm = dummy_initialize_llm
llm.merge_chat_history = dummy_merge_chat_history
llm.validate_provider_env = dummy_validate_provider_env
# Set failure counter high enough to trigger fallback in _handle_tool_failure
self.agent.tool_failure_consecutive_failures = 2
# Call _attempt_fallback; it should reset the failure counter to 0 on success.
self.agent._attempt_fallback("always_fail_tool()")
self.assertEqual(self.agent.tool_failure_consecutive_failures, 0)
# Restore original functions
llm.initialize_llm = original_initialize
llm.merge_chat_history = original_merge
llm.validate_provider_env = original_validate
pass
if __name__ == "__main__":

View File

@ -0,0 +1,58 @@
import unittest
from ra_aid.fallback_handler import FallbackHandler
class DummyLogger:
def debug(self, msg):
pass
def error(self, msg):
pass
class DummyAgent:
provider = "openai"
tools = []
model = None
class TestFallbackHandler(unittest.TestCase):
def setUp(self):
self.config = {"max_tool_failures": 2, "fallback_tool_models": "dummy-fallback-model"}
self.fallback_handler = FallbackHandler(self.config)
self.logger = DummyLogger()
self.agent = DummyAgent()
def test_handle_failure_increments_counter(self):
initial_failures = self.fallback_handler.tool_failure_consecutive_failures
self.fallback_handler.handle_failure("dummy_call()", Exception("Test error"), self.logger, self.agent)
self.assertEqual(self.fallback_handler.tool_failure_consecutive_failures, initial_failures + 1)
def test_attempt_fallback_resets_counter(self):
# Monkey-patch dummy functions for fallback components
def dummy_initialize_llm(provider, model_name, temperature=None):
class DummyModel:
def bind_tools(self, tools, tool_choice):
pass
return DummyModel()
def dummy_merge_chat_history():
return ["merged"]
def dummy_validate_provider_env(provider):
return True
import ra_aid.llm as llm
original_initialize = llm.initialize_llm
original_merge = llm.merge_chat_history
original_validate = llm.validate_provider_env
llm.initialize_llm = dummy_initialize_llm
llm.merge_chat_history = dummy_merge_chat_history
llm.validate_provider_env = dummy_validate_provider_env
self.fallback_handler.tool_failure_consecutive_failures = 2
self.fallback_handler.attempt_fallback("dummy_tool_call()", self.logger, self.agent)
self.assertEqual(self.fallback_handler.tool_failure_consecutive_failures, 0)
llm.initialize_llm = original_initialize
llm.merge_chat_history = original_merge
llm.validate_provider_env = original_validate
if __name__ == "__main__":
unittest.main()