feat(fallback_handler): implement FallbackHandler class to manage tool failures and fallback logic
refactor(ciayn_agent): integrate FallbackHandler into CiaynAgent for improved failure handling
fix(agent_utils): add missing newline for better readability in run_agent_with_retry function
test(fallback_handler): add unit tests for FallbackHandler to ensure correct failure handling and fallback logic
parent d8ee4e04f4
commit 55abf6e5dd
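The flow described in the commit message can be summarized with a short sketch (an editorial illustration, not part of the diff): a minimal agent that defers all failure bookkeeping to FallbackHandler. SketchAgent, execute_tool, and the eval placeholder are invented names; only FallbackHandler and its handle_failure/reset_fallback_handler methods come from this commit.

# Editorial sketch, not part of the commit. SketchAgent stands in for CiaynAgent;
# eval() stands in for the real tool invocation.
import logging

from ra_aid.fallback_handler import FallbackHandler

logger = logging.getLogger(__name__)


class SketchAgent:
    # handle_failure reads provider/tools/model off the agent it is given.
    provider = "openai"
    tools = []
    model = None

    def __init__(self, config: dict):
        self.fallback_handler = FallbackHandler(config)

    def execute_tool(self, code: str):
        try:
            result = eval(code)  # placeholder for the real tool call
            # Any success clears the consecutive-failure streak.
            self.fallback_handler.reset_fallback_handler()
            return result
        except Exception as e:
            # After max_tool_failures consecutive errors, the handler tries to
            # switch the agent onto a fallback tool-calling model.
            self.fallback_handler.handle_failure(code, e, logger, self)


agent = SketchAgent({"max_tool_failures": 2, "fallback_tool_models": "dummy-model"})
agent.execute_tool("missing_tool()")  # failure 1 (NameError inside eval)
agent.execute_tool("missing_tool()")  # failure 2 -> fallback attempt is triggered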
@@ -806,6 +806,7 @@ def run_agent_with_retry(agent, prompt: str, config: dict) -> Optional[str]:
            logger.debug("Agent output: %s", chunk)
            check_interrupt()
            print_agent_output(chunk)

            if _global_memory["plan_completed"]:
                _global_memory["plan_completed"] = False
                _global_memory["task_completed"] = False
@@ -4,7 +4,7 @@ from typing import Any, Dict, Generator, List, Optional, Union
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage

from ra_aid.config import DEFAULT_MAX_TOOL_FAILURES
from ra_aid.fallback_handler import FallbackHandler
from ra_aid.exceptions import ToolExecutionError
from ra_aid.logging_config import get_logger
from ra_aid.models_params import DEFAULT_TOKEN_LIMIT
@@ -90,8 +90,7 @@ class CiaynAgent:
            config = {}
        self.config = config
        self.provider = config.get("provider", "openai")
        self.fallback_enabled = config.get("fallback_tool_enabled", True)
        self.fallback_tool_models = self._load_fallback_tool_models(config)
        self.fallback_handler = FallbackHandler(config)

        self.model = model
        self.tools = tools
@@ -100,18 +99,8 @@ class CiaynAgent:
        self.available_functions = []
        for t in tools:
            self.available_functions.append(get_function_info(t.func))
        self.tool_failure_consecutive_failures = 0
        self.tool_failure_current_provider = None
        self.tool_failure_current_model = None
        self.tool_failure_used_fallbacks = set()

    def _load_fallback_tool_models(self, config: dict) -> list:
        fallback_tool_models_config = config.get("fallback_tool_models")
        if fallback_tool_models_config:
            return [m.strip() for m in fallback_tool_models_config.split(",") if m.strip()]
        else:
            from ra_aid.tool_leaderboard import supported_top_tool_models
            return [item["model"] for item in supported_top_tool_models[:5]]

    def _build_prompt(self, last_result: Optional[str] = None) -> str:
        """Build the prompt for the agent including available tools and context."""
@@ -262,7 +251,7 @@ Output **ONLY THE CODE** and **NO MARKDOWN BACKTICKS**"""
            logger.debug(
                f"_execute_tool: tool executed successfully with result: {result}"
            )
            self.tool_failure_consecutive_failures = 0
            self.fallback_handler.reset_fallback_handler()
            return result
        except Exception as e:
            logger.debug(f"_execute_tool: exception caught: {e}")
@@ -275,61 +264,7 @@ Output **ONLY THE CODE** and **NO MARKDOWN BACKTICKS**"""
            )

    def _handle_tool_failure(self, code: str, error: Exception) -> None:
        logger.debug(
            f"_handle_tool_failure: tool failure encountered for code '{code}' with error: {error}"
        )
        self.tool_failure_consecutive_failures += 1
        max_failures = self.config.get("max_tool_failures", DEFAULT_MAX_TOOL_FAILURES)
        logger.debug(
            f"_handle_tool_failure: failure count {self.tool_failure_consecutive_failures}, max_failures {max_failures}"
        )
        if (
            self.fallback_enabled
            and self.tool_failure_consecutive_failures >= max_failures
            and self.fallback_tool_models
        ):
            logger.debug(
                "_handle_tool_failure: threshold reached, invoking fallback mechanism."
            )
            self._attempt_fallback(code)

    def _attempt_fallback(self, code: str) -> None:
        logger.debug(f"_attempt_fallback: initiating fallback for code: {code}")
        new_model = self.fallback_tool_models[0]
        failed_tool_call_name = code.split("(")[0].strip()
        logger.error(
            f"Tool call failed {self.tool_failure_consecutive_failures} times. Attempting fallback to model: {new_model} for tool: {failed_tool_call_name}"
        )
        try:
            from ra_aid.llm import (
                initialize_llm,
                merge_chat_history,
                validate_provider_env,
            )

            logger.debug(f"_attempt_fallback: validating provider {self.provider}")
            if not validate_provider_env(self.provider):
                logger.error(
                    f"Missing environment configuration for provider {self.provider}. Cannot fallback."
                )
            else:
                logger.debug(
                    f"_attempt_fallback: initializing fallback model {new_model}"
                )
                self.model = initialize_llm(self.provider, new_model)
                logger.debug(
                    f"_attempt_fallback: binding tools to new model using tool: {failed_tool_call_name}"
                )
                self.model.bind_tools(self.tools, tool_choice=failed_tool_call_name)
                self.tool_failure_used_fallbacks.add(new_model)
                logger.debug("_attempt_fallback: merging chat history for fallback")
                merge_chat_history()
                self.tool_failure_consecutive_failures = 0
                logger.debug(
                    "_attempt_fallback: fallback successful and tool failure counter reset"
                )
        except Exception as switch_e:
            logger.error(f"Fallback model switching failed: {switch_e}")
        self.fallback_handler.handle_failure(code, error, logger, self)

    def _create_agent_chunk(self, content: str) -> Dict[str, Any]:
        """Create an agent chunk in the format expected by print_agent_output."""
@@ -3,6 +3,7 @@
DEFAULT_RECURSION_LIMIT = 100
DEFAULT_MAX_TEST_CMD_RETRIES = 3
DEFAULT_MAX_TOOL_FAILURES = 3
FALLBACK_TOOL_MODEL_LIMIT = 5

VALID_PROVIDERS = [
    "anthropic",
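For context, these constants back the config keys the new code reads; below is a hedged example of a config dict wiring them together (the key names come from this commit, the model names are invented):

# Illustrative config only; not part of the diff.
from ra_aid.config import DEFAULT_MAX_TOOL_FAILURES

config = {
    "provider": "openai",                      # CiaynAgent defaults to "openai" if absent
    "fallback_tool_enabled": True,             # FallbackHandler defaults this to True
    "max_tool_failures": DEFAULT_MAX_TOOL_FAILURES,  # threshold before a fallback attempt
    # Comma-separated override; if omitted, FallbackHandler picks leaderboard models
    # whose provider env vars validate, capped at FALLBACK_TOOL_MODEL_LIMIT.
    "fallback_tool_models": "model-a, model-b",
}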
@@ -0,0 +1,99 @@
from ra_aid.config import DEFAULT_MAX_TOOL_FAILURES, FALLBACK_TOOL_MODEL_LIMIT
from ra_aid.tool_leaderboard import supported_top_tool_models
from rich.console import Console
from rich.markdown import Markdown
from rich.panel import Panel
from ra_aid.llm import initialize_llm, merge_chat_history, validate_provider_env


class FallbackHandler:
    def __init__(self, config):
        self.config = config
        self.fallback_enabled = config.get("fallback_tool_enabled", True)
        self.fallback_tool_models = self._load_fallback_tool_models(config)
        self.tool_failure_consecutive_failures = 0
        self.tool_failure_used_fallbacks = set()

    def _load_fallback_tool_models(self, config):
        fallback_tool_models_config = config.get("fallback_tool_models")
        if fallback_tool_models_config:
            return [
                m.strip() for m in fallback_tool_models_config.split(",") if m.strip()
            ]
        else:
            console = Console()
            supported = []
            skipped = []
            for item in supported_top_tool_models:
                provider = item.get("provider")
                model_name = item.get("model")
                if validate_provider_env(provider):
                    supported.append(model_name)
                    if len(supported) == FALLBACK_TOOL_MODEL_LIMIT:
                        break
                else:
                    skipped.append(model_name)
            final_models = supported[:FALLBACK_TOOL_MODEL_LIMIT]
            message = "Fallback models selected: " + ", ".join(final_models)
            if skipped:
                message += (
                    "\nSkipped top tool calling models due to missing provider ENV API keys: "
                    + ", ".join(skipped)
                )
            console.print(Panel(Markdown(message), title="Fallback Models"))
            return final_models

    def handle_failure(self, code: str, error: Exception, logger, agent):
        logger.debug(
            f"_handle_tool_failure: tool failure encountered for code '{code}' with error: {error}"
        )
        self.tool_failure_consecutive_failures += 1
        max_failures = self.config.get("max_tool_failures", DEFAULT_MAX_TOOL_FAILURES)
        logger.debug(
            f"_handle_tool_failure: failure count {self.tool_failure_consecutive_failures}, max_failures {max_failures}"
        )
        if (
            self.fallback_enabled
            and self.tool_failure_consecutive_failures >= max_failures
            and self.fallback_tool_models
        ):
            logger.debug(
                "_handle_tool_failure: threshold reached, invoking fallback mechanism."
            )
            self.attempt_fallback(code, logger, agent)

    def attempt_fallback(self, code: str, logger, agent):
        logger.debug(f"_attempt_fallback: initiating fallback for code: {code}")
        new_model = self.fallback_tool_models[0]
        failed_tool_call_name = code.split("(")[0].strip()
        logger.error(
            f"Tool call failed {self.tool_failure_consecutive_failures} times. Attempting fallback to model: {new_model} for tool: {failed_tool_call_name}"
        )
        try:
            logger.debug(f"_attempt_fallback: validating provider {agent.provider}")
            if not validate_provider_env(agent.provider):
                logger.error(
                    f"Missing environment configuration for provider {agent.provider}. Cannot fallback."
                )
            else:
                logger.debug(
                    f"_attempt_fallback: initializing fallback model {new_model}"
                )
                agent.model = initialize_llm(agent.provider, new_model)
                logger.debug(
                    f"_attempt_fallback: binding tools to new model using tool: {failed_tool_call_name}"
                )
                agent.model.bind_tools(agent.tools, tool_choice=failed_tool_call_name)
                self.tool_failure_used_fallbacks.add(new_model)
                logger.debug("_attempt_fallback: merging chat history for fallback")
                merge_chat_history()
                self.tool_failure_consecutive_failures = 0
                logger.debug(
                    "_attempt_fallback: fallback successful and tool failure counter reset"
                )
        except Exception as switch_e:
            logger.error(f"Fallback model switching failed: {switch_e}")

    def reset_fallback_handler(self):
        self.tool_failure_consecutive_failures = 0
        self.tool_failure_used_fallbacks.clear()
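A quick hedged check of the comma-splitting path in _load_fallback_tool_models (editorial sketch with placeholder model names; when the key is set, no provider env validation happens):

# Sketch: the override string is split on commas and whitespace-stripped.
from ra_aid.fallback_handler import FallbackHandler

handler = FallbackHandler({"fallback_tool_models": " model-a , model-b ,"})
assert handler.fallback_tool_models == ["model-a", "model-b"]
assert handler.tool_failure_consecutive_failures == 0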
@@ -264,61 +264,7 @@ class TestFunctionCallValidation:


class TestCiaynAgentNewMethods(unittest.TestCase):
    def setUp(self):
        # Create a dummy tool that always fails for testing fallback
        def always_fail():
            raise Exception("Failure for fallback test")

        self.always_fail_tool = DummyTool(always_fail)
        # Create a dummy model that does minimal work for fallback tests
        self.dummy_model = DummyModel()
        # Initialize CiaynAgent with configuration to trigger fallback quickly
        self.agent = CiaynAgent(
            self.dummy_model,
            [self.always_fail_tool],
            config={
                "max_tool_failures": 2,
                "fallback_tool_models": "dummy-fallback-model",
            },
        )

    def test_handle_tool_failure_increments_counter(self):
        initial_failures = self.agent.tool_failure_consecutive_failures
        self.agent._handle_tool_failure("dummy_call()", Exception("Test error"))
        self.assertEqual(
            self.agent.tool_failure_consecutive_failures, initial_failures + 1
        )

    def test_attempt_fallback_invokes_fallback_logic(self):
        # Monkey-patch initialize_llm, merge_chat_history, and validate_provider_env
        # to simulate fallback switching without external dependencies.
        def dummy_initialize_llm(provider, model_name, temperature=None):
            return self.dummy_model

        def dummy_merge_chat_history():
            return ["merged"]

        def dummy_validate_provider_env(provider):
            return True

        import ra_aid.llm as llm

        original_initialize = llm.initialize_llm
        original_merge = llm.merge_chat_history
        original_validate = llm.validate_provider_env
        llm.initialize_llm = dummy_initialize_llm
        llm.merge_chat_history = dummy_merge_chat_history
        llm.validate_provider_env = dummy_validate_provider_env

        # Set failure counter high enough to trigger fallback in _handle_tool_failure
        self.agent.tool_failure_consecutive_failures = 2
        # Call _attempt_fallback; it should reset the failure counter to 0 on success.
        self.agent._attempt_fallback("always_fail_tool()")
        self.assertEqual(self.agent.tool_failure_consecutive_failures, 0)
        # Restore original functions
        llm.initialize_llm = original_initialize
        llm.merge_chat_history = original_merge
        llm.validate_provider_env = original_validate
        pass


if __name__ == "__main__":
@@ -0,0 +1,58 @@
import unittest
from ra_aid.fallback_handler import FallbackHandler


class DummyLogger:
    def debug(self, msg):
        pass

    def error(self, msg):
        pass


class DummyAgent:
    provider = "openai"
    tools = []
    model = None


class TestFallbackHandler(unittest.TestCase):
    def setUp(self):
        self.config = {"max_tool_failures": 2, "fallback_tool_models": "dummy-fallback-model"}
        self.fallback_handler = FallbackHandler(self.config)
        self.logger = DummyLogger()
        self.agent = DummyAgent()

    def test_handle_failure_increments_counter(self):
        initial_failures = self.fallback_handler.tool_failure_consecutive_failures
        self.fallback_handler.handle_failure("dummy_call()", Exception("Test error"), self.logger, self.agent)
        self.assertEqual(self.fallback_handler.tool_failure_consecutive_failures, initial_failures + 1)

    def test_attempt_fallback_resets_counter(self):
        # Monkey-patch dummy functions for fallback components
        def dummy_initialize_llm(provider, model_name, temperature=None):
            class DummyModel:
                def bind_tools(self, tools, tool_choice):
                    pass
            return DummyModel()

        def dummy_merge_chat_history():
            return ["merged"]

        def dummy_validate_provider_env(provider):
            return True

        import ra_aid.llm as llm
        original_initialize = llm.initialize_llm
        original_merge = llm.merge_chat_history
        original_validate = llm.validate_provider_env
        llm.initialize_llm = dummy_initialize_llm
        llm.merge_chat_history = dummy_merge_chat_history
        llm.validate_provider_env = dummy_validate_provider_env

        self.fallback_handler.tool_failure_consecutive_failures = 2
        self.fallback_handler.attempt_fallback("dummy_tool_call()", self.logger, self.agent)
        self.assertEqual(self.fallback_handler.tool_failure_consecutive_failures, 0)

        llm.initialize_llm = original_initialize
        llm.merge_chat_history = original_merge
        llm.validate_provider_env = original_validate


if __name__ == "__main__":
    unittest.main()