feat(fallback_handler): implement FallbackHandler class to manage tool failures and fallback logic
refactor(ciayn_agent): integrate FallbackHandler into CiaynAgent for improved failure handling
fix(agent_utils): add missing newline for better readability in the run_agent_with_retry function
test(fallback_handler): add unit tests for FallbackHandler to ensure correct failure handling and fallback logic
This commit is contained in:
parent
d8ee4e04f4
commit
55abf6e5dd
|
|
@ -806,6 +806,7 @@ def run_agent_with_retry(agent, prompt: str, config: dict) -> Optional[str]:
|
||||||
logger.debug("Agent output: %s", chunk)
|
logger.debug("Agent output: %s", chunk)
|
||||||
check_interrupt()
|
check_interrupt()
|
||||||
print_agent_output(chunk)
|
print_agent_output(chunk)
|
||||||
|
|
||||||
if _global_memory["plan_completed"]:
|
if _global_memory["plan_completed"]:
|
||||||
_global_memory["plan_completed"] = False
|
_global_memory["plan_completed"] = False
|
||||||
_global_memory["task_completed"] = False
|
_global_memory["task_completed"] = False
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@ from typing import Any, Dict, Generator, List, Optional, Union
|
||||||
|
|
||||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
|
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
|
||||||
|
|
||||||
from ra_aid.config import DEFAULT_MAX_TOOL_FAILURES
|
from ra_aid.fallback_handler import FallbackHandler
|
||||||
from ra_aid.exceptions import ToolExecutionError
|
from ra_aid.exceptions import ToolExecutionError
|
||||||
from ra_aid.logging_config import get_logger
|
from ra_aid.logging_config import get_logger
|
||||||
from ra_aid.models_params import DEFAULT_TOKEN_LIMIT
|
from ra_aid.models_params import DEFAULT_TOKEN_LIMIT
|
||||||
|
|
@ -90,8 +90,7 @@ class CiaynAgent:
|
||||||
config = {}
|
config = {}
|
||||||
self.config = config
|
self.config = config
|
||||||
self.provider = config.get("provider", "openai")
|
self.provider = config.get("provider", "openai")
|
||||||
self.fallback_enabled = config.get("fallback_tool_enabled", True)
|
self.fallback_handler = FallbackHandler(config)
|
||||||
self.fallback_tool_models = self._load_fallback_tool_models(config)
|
|
||||||
|
|
||||||
self.model = model
|
self.model = model
|
||||||
self.tools = tools
|
self.tools = tools
|
||||||
|
|
@ -100,18 +99,8 @@ class CiaynAgent:
|
||||||
self.available_functions = []
|
self.available_functions = []
|
||||||
for t in tools:
|
for t in tools:
|
||||||
self.available_functions.append(get_function_info(t.func))
|
self.available_functions.append(get_function_info(t.func))
|
||||||
self.tool_failure_consecutive_failures = 0
|
|
||||||
self.tool_failure_current_provider = None
|
self.tool_failure_current_provider = None
|
||||||
self.tool_failure_current_model = None
|
self.tool_failure_current_model = None
|
||||||
self.tool_failure_used_fallbacks = set()
|
|
||||||
|
|
||||||
def _load_fallback_tool_models(self, config: dict) -> list:
|
|
||||||
fallback_tool_models_config = config.get("fallback_tool_models")
|
|
||||||
if fallback_tool_models_config:
|
|
||||||
return [m.strip() for m in fallback_tool_models_config.split(",") if m.strip()]
|
|
||||||
else:
|
|
||||||
from ra_aid.tool_leaderboard import supported_top_tool_models
|
|
||||||
return [item["model"] for item in supported_top_tool_models[:5]]
|
|
||||||
|
|
||||||
def _build_prompt(self, last_result: Optional[str] = None) -> str:
|
def _build_prompt(self, last_result: Optional[str] = None) -> str:
|
||||||
"""Build the prompt for the agent including available tools and context."""
|
"""Build the prompt for the agent including available tools and context."""
|
||||||
|
|
@ -262,7 +251,7 @@ Output **ONLY THE CODE** and **NO MARKDOWN BACKTICKS**"""
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"_execute_tool: tool executed successfully with result: {result}"
|
f"_execute_tool: tool executed successfully with result: {result}"
|
||||||
)
|
)
|
||||||
self.tool_failure_consecutive_failures = 0
|
self.fallback_handler.reset_fallback_handler()
|
||||||
return result
|
return result
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug(f"_execute_tool: exception caught: {e}")
|
logger.debug(f"_execute_tool: exception caught: {e}")
|
||||||
|
|
@ -275,61 +264,7 @@ Output **ONLY THE CODE** and **NO MARKDOWN BACKTICKS**"""
|
||||||
)
|
)
|
||||||
|
|
||||||
def _handle_tool_failure(self, code: str, error: Exception) -> None:
|
def _handle_tool_failure(self, code: str, error: Exception) -> None:
|
||||||
logger.debug(
|
self.fallback_handler.handle_failure(code, error, logger, self)
|
||||||
f"_handle_tool_failure: tool failure encountered for code '{code}' with error: {error}"
|
|
||||||
)
|
|
||||||
self.tool_failure_consecutive_failures += 1
|
|
||||||
max_failures = self.config.get("max_tool_failures", DEFAULT_MAX_TOOL_FAILURES)
|
|
||||||
logger.debug(
|
|
||||||
f"_handle_tool_failure: failure count {self.tool_failure_consecutive_failures}, max_failures {max_failures}"
|
|
||||||
)
|
|
||||||
if (
|
|
||||||
self.fallback_enabled
|
|
||||||
and self.tool_failure_consecutive_failures >= max_failures
|
|
||||||
and self.fallback_tool_models
|
|
||||||
):
|
|
||||||
logger.debug(
|
|
||||||
"_handle_tool_failure: threshold reached, invoking fallback mechanism."
|
|
||||||
)
|
|
||||||
self._attempt_fallback(code)
|
|
||||||
|
|
||||||
def _attempt_fallback(self, code: str) -> None:
|
|
||||||
logger.debug(f"_attempt_fallback: initiating fallback for code: {code}")
|
|
||||||
new_model = self.fallback_tool_models[0]
|
|
||||||
failed_tool_call_name = code.split("(")[0].strip()
|
|
||||||
logger.error(
|
|
||||||
f"Tool call failed {self.tool_failure_consecutive_failures} times. Attempting fallback to model: {new_model} for tool: {failed_tool_call_name}"
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
from ra_aid.llm import (
|
|
||||||
initialize_llm,
|
|
||||||
merge_chat_history,
|
|
||||||
validate_provider_env,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.debug(f"_attempt_fallback: validating provider {self.provider}")
|
|
||||||
if not validate_provider_env(self.provider):
|
|
||||||
logger.error(
|
|
||||||
f"Missing environment configuration for provider {self.provider}. Cannot fallback."
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logger.debug(
|
|
||||||
f"_attempt_fallback: initializing fallback model {new_model}"
|
|
||||||
)
|
|
||||||
self.model = initialize_llm(self.provider, new_model)
|
|
||||||
logger.debug(
|
|
||||||
f"_attempt_fallback: binding tools to new model using tool: {failed_tool_call_name}"
|
|
||||||
)
|
|
||||||
self.model.bind_tools(self.tools, tool_choice=failed_tool_call_name)
|
|
||||||
self.tool_failure_used_fallbacks.add(new_model)
|
|
||||||
logger.debug("_attempt_fallback: merging chat history for fallback")
|
|
||||||
merge_chat_history()
|
|
||||||
self.tool_failure_consecutive_failures = 0
|
|
||||||
logger.debug(
|
|
||||||
"_attempt_fallback: fallback successful and tool failure counter reset"
|
|
||||||
)
|
|
||||||
except Exception as switch_e:
|
|
||||||
logger.error(f"Fallback model switching failed: {switch_e}")
|
|
||||||
|
|
||||||
def _create_agent_chunk(self, content: str) -> Dict[str, Any]:
|
def _create_agent_chunk(self, content: str) -> Dict[str, Any]:
|
||||||
"""Create an agent chunk in the format expected by print_agent_output."""
|
"""Create an agent chunk in the format expected by print_agent_output."""
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,7 @@
|
||||||
DEFAULT_RECURSION_LIMIT = 100
|
DEFAULT_RECURSION_LIMIT = 100
|
||||||
DEFAULT_MAX_TEST_CMD_RETRIES = 3
|
DEFAULT_MAX_TEST_CMD_RETRIES = 3
|
||||||
DEFAULT_MAX_TOOL_FAILURES = 3
|
DEFAULT_MAX_TOOL_FAILURES = 3
|
||||||
|
FALLBACK_TOOL_MODEL_LIMIT = 5
|
||||||
|
|
||||||
VALID_PROVIDERS = [
|
VALID_PROVIDERS = [
|
||||||
"anthropic",
|
"anthropic",
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,99 @@
|
||||||
|
from ra_aid.config import DEFAULT_MAX_TOOL_FAILURES, FALLBACK_TOOL_MODEL_LIMIT
|
||||||
|
from ra_aid.tool_leaderboard import supported_top_tool_models
|
||||||
|
from rich.console import Console
|
||||||
|
from rich.markdown import Markdown
|
||||||
|
from rich.panel import Panel
|
||||||
|
from ra_aid.llm import initialize_llm, merge_chat_history, validate_provider_env
|
||||||
|
|
||||||
|
|
||||||
|
class FallbackHandler:
    """Count consecutive tool failures and switch an agent to a fallback model.

    The handler is driven by a plain configuration mapping:

    * ``fallback_tool_enabled`` (default ``True``) turns the mechanism on/off.
    * ``fallback_tool_models`` is an optional comma-separated string of model
      names. When it is missing or empty, candidates are taken from
      ``supported_top_tool_models`` instead.
    * ``max_tool_failures`` (default ``DEFAULT_MAX_TOOL_FAILURES``) is the
      number of consecutive failures that triggers a fallback attempt.
    """

    def __init__(self, config):
        """Read the fallback settings from *config* and reset all counters.

        Args:
            config: Mapping supporting ``.get``; see the class docstring for
                the keys that are consulted.
        """
        self.config = config
        self.fallback_enabled = config.get("fallback_tool_enabled", True)
        # model name -> provider; only filled for leaderboard-derived models.
        # Must exist before _load_fallback_tool_models runs.
        self._fallback_model_providers = {}
        self.fallback_tool_models = self._load_fallback_tool_models(config)
        self.tool_failure_consecutive_failures = 0
        self.tool_failure_used_fallbacks = set()

    def _load_fallback_tool_models(self, config):
        """Return the ordered list of fallback model names.

        If ``config["fallback_tool_models"]`` is a non-empty string it is split
        on commas and blank entries are dropped. Otherwise the entries of
        ``supported_top_tool_models`` are scanned in order, keeping up to
        ``FALLBACK_TOOL_MODEL_LIMIT`` models whose provider passes
        ``validate_provider_env``; the selection (and any skipped models) is
        printed to the console in a panel.
        """
        configured = config.get("fallback_tool_models")
        if configured:
            return [m.strip() for m in configured.split(",") if m.strip()]

        console = Console()
        supported = []
        skipped = []
        providers = {}
        for item in supported_top_tool_models:
            provider = item.get("provider")
            model_name = item.get("model")
            if not validate_provider_env(provider):
                skipped.append(model_name)
                continue
            supported.append(model_name)
            providers[model_name] = provider
            if len(supported) == FALLBACK_TOOL_MODEL_LIMIT:
                break

        final_models = supported[:FALLBACK_TOOL_MODEL_LIMIT]
        # Remember which provider each selected model was validated against so
        # attempt_fallback can initialise it with that provider rather than
        # with the agent's current one.
        self._fallback_model_providers = {m: providers[m] for m in final_models}

        message = "Fallback models selected: " + ", ".join(final_models)
        if skipped:
            message += (
                "\nSkipped top tool calling models due to missing provider ENV API keys: "
                + ", ".join(skipped)
            )
        console.print(Panel(Markdown(message), title="Fallback Models"))
        return final_models

    def handle_failure(self, code: str, error: Exception, logger, agent):
        """Record one tool failure and trigger a fallback once the limit is hit.

        Args:
            code: The tool-call source text that failed.
            error: The exception raised by the tool call (only logged).
            logger: Object exposing ``debug`` and ``error`` methods.
            agent: The agent whose model may be replaced; forwarded to
                ``attempt_fallback``.
        """
        logger.debug(
            f"_handle_tool_failure: tool failure encountered for code '{code}' with error: {error}"
        )
        self.tool_failure_consecutive_failures += 1
        max_failures = self.config.get("max_tool_failures", DEFAULT_MAX_TOOL_FAILURES)
        logger.debug(
            f"_handle_tool_failure: failure count {self.tool_failure_consecutive_failures}, max_failures {max_failures}"
        )
        if (
            self.fallback_enabled
            and self.tool_failure_consecutive_failures >= max_failures
            and self.fallback_tool_models
        ):
            logger.debug(
                "_handle_tool_failure: threshold reached, invoking fallback mechanism."
            )
            self.attempt_fallback(code, logger, agent)

    def attempt_fallback(self, code: str, logger, agent):
        """Switch *agent* to the next fallback model that has not been tried.

        On success ``agent.model`` is replaced and the consecutive-failure
        counter is reset to 0. Every failure path is logged through *logger*
        and leaves the counter untouched; nothing is raised to the caller.

        Args:
            code: The failing tool-call text; the part before the first ``(``
                is used as the tool name.
            logger: Object exposing ``debug`` and ``error`` methods.
            agent: Object with ``provider``, ``tools`` and ``model`` attributes.
        """
        logger.debug(f"_attempt_fallback: initiating fallback for code: {code}")

        # Pick the first candidate that has not been tried since the last
        # reset, instead of unconditionally retrying fallback_tool_models[0].
        new_model = next(
            (
                candidate
                for candidate in self.fallback_tool_models
                if candidate not in self.tool_failure_used_fallbacks
            ),
            None,
        )
        if new_model is None:
            # Also covers an empty model list, which used to raise IndexError.
            logger.error("No unused fallback models available. Cannot fallback.")
            return

        # Mark the candidate as tried up front so that a model which fails to
        # initialise is not selected again on the next attempt.
        self.tool_failure_used_fallbacks.add(new_model)

        failed_tool_call_name = code.split("(")[0].strip()
        logger.error(
            f"Tool call failed {self.tool_failure_consecutive_failures} times. Attempting fallback to model: {new_model} for tool: {failed_tool_call_name}"
        )

        # Leaderboard models carry their own provider; explicitly configured
        # model names fall back to the agent's current provider.
        provider = getattr(self, "_fallback_model_providers", {}).get(
            new_model, agent.provider
        )
        try:
            logger.debug(f"_attempt_fallback: validating provider {provider}")
            if not validate_provider_env(provider):
                logger.error(
                    f"Missing environment configuration for provider {provider}. Cannot fallback."
                )
                return

            logger.debug(f"_attempt_fallback: initializing fallback model {new_model}")
            agent.model = initialize_llm(provider, new_model)
            logger.debug(
                f"_attempt_fallback: binding tools to new model using tool: {failed_tool_call_name}"
            )
            # NOTE(review): the return value of bind_tools is discarded here,
            # as in the original code. If bind_tools returns a new bound model
            # rather than mutating in place, this call has no lasting effect -
            # confirm against the model class in use.
            agent.model.bind_tools(agent.tools, tool_choice=failed_tool_call_name)
            logger.debug("_attempt_fallback: merging chat history for fallback")
            merge_chat_history()
            self.tool_failure_consecutive_failures = 0
            logger.debug(
                "_attempt_fallback: fallback successful and tool failure counter reset"
            )
        except Exception as switch_e:
            # Best-effort boundary: a failed switch must not crash the agent.
            logger.error(f"Fallback model switching failed: {switch_e}")

    def reset_fallback_handler(self):
        """Reset the failure counter and forget which fallbacks were tried."""
        self.tool_failure_consecutive_failures = 0
        self.tool_failure_used_fallbacks.clear()
|
||||||
|
|
@ -264,61 +264,7 @@ class TestFunctionCallValidation:
|
||||||
|
|
||||||
|
|
||||||
class TestCiaynAgentNewMethods(unittest.TestCase):
|
class TestCiaynAgentNewMethods(unittest.TestCase):
|
||||||
def setUp(self):
|
pass
|
||||||
# Create a dummy tool that always fails for testing fallback
|
|
||||||
def always_fail():
|
|
||||||
raise Exception("Failure for fallback test")
|
|
||||||
|
|
||||||
self.always_fail_tool = DummyTool(always_fail)
|
|
||||||
# Create a dummy model that does minimal work for fallback tests
|
|
||||||
self.dummy_model = DummyModel()
|
|
||||||
# Initialize CiaynAgent with configuration to trigger fallback quickly
|
|
||||||
self.agent = CiaynAgent(
|
|
||||||
self.dummy_model,
|
|
||||||
[self.always_fail_tool],
|
|
||||||
config={
|
|
||||||
"max_tool_failures": 2,
|
|
||||||
"fallback_tool_models": "dummy-fallback-model",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_handle_tool_failure_increments_counter(self):
|
|
||||||
initial_failures = self.agent.tool_failure_consecutive_failures
|
|
||||||
self.agent._handle_tool_failure("dummy_call()", Exception("Test error"))
|
|
||||||
self.assertEqual(
|
|
||||||
self.agent.tool_failure_consecutive_failures, initial_failures + 1
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_attempt_fallback_invokes_fallback_logic(self):
|
|
||||||
# Monkey-patch initialize_llm, merge_chat_history, and validate_provider_env
|
|
||||||
# to simulate fallback switching without external dependencies.
|
|
||||||
def dummy_initialize_llm(provider, model_name, temperature=None):
|
|
||||||
return self.dummy_model
|
|
||||||
|
|
||||||
def dummy_merge_chat_history():
|
|
||||||
return ["merged"]
|
|
||||||
|
|
||||||
def dummy_validate_provider_env(provider):
|
|
||||||
return True
|
|
||||||
|
|
||||||
import ra_aid.llm as llm
|
|
||||||
|
|
||||||
original_initialize = llm.initialize_llm
|
|
||||||
original_merge = llm.merge_chat_history
|
|
||||||
original_validate = llm.validate_provider_env
|
|
||||||
llm.initialize_llm = dummy_initialize_llm
|
|
||||||
llm.merge_chat_history = dummy_merge_chat_history
|
|
||||||
llm.validate_provider_env = dummy_validate_provider_env
|
|
||||||
|
|
||||||
# Set failure counter high enough to trigger fallback in _handle_tool_failure
|
|
||||||
self.agent.tool_failure_consecutive_failures = 2
|
|
||||||
# Call _attempt_fallback; it should reset the failure counter to 0 on success.
|
|
||||||
self.agent._attempt_fallback("always_fail_tool()")
|
|
||||||
self.assertEqual(self.agent.tool_failure_consecutive_failures, 0)
|
|
||||||
# Restore original functions
|
|
||||||
llm.initialize_llm = original_initialize
|
|
||||||
llm.merge_chat_history = original_merge
|
|
||||||
llm.validate_provider_env = original_validate
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,58 @@
|
||||||
|
import unittest
|
||||||
|
from ra_aid.fallback_handler import FallbackHandler
|
||||||
|
|
||||||
|
class DummyLogger:
    """Silent stand-in for a logger; swallows every message.

    The methods accept arbitrary extra positional/keyword arguments so the
    stub also works with lazy ``%``-style calls such as
    ``logger.debug("x=%s", x)``.
    """

    def debug(self, msg, *args, **kwargs):
        """Discard a debug message."""

    def error(self, msg, *args, **kwargs):
        """Discard an error message."""
|
||||||
|
|
||||||
|
class DummyAgent:
    """Minimal agent stub exposing the attributes FallbackHandler touches."""

    provider = "openai"
    model = None

    def __init__(self):
        # Per-instance list: a class-level ``tools = []`` would be shared (and
        # could be mutated) across every test that creates a DummyAgent.
        self.tools = []
|
||||||
|
|
||||||
|
class TestFallbackHandler(unittest.TestCase):
    """Unit tests for FallbackHandler failure counting and model switching."""

    def setUp(self):
        """Build a handler with one explicitly configured fallback model."""
        self.config = {"max_tool_failures": 2, "fallback_tool_models": "dummy-fallback-model"}
        self.fallback_handler = FallbackHandler(self.config)
        self.logger = DummyLogger()
        self.agent = DummyAgent()

    def test_handle_failure_increments_counter(self):
        """A single failure below the threshold only bumps the counter."""
        initial_failures = self.fallback_handler.tool_failure_consecutive_failures
        self.fallback_handler.handle_failure(
            "dummy_call()", Exception("Test error"), self.logger, self.agent
        )
        self.assertEqual(
            self.fallback_handler.tool_failure_consecutive_failures,
            initial_failures + 1,
        )

    def test_attempt_fallback_resets_counter(self):
        """A successful fallback resets the consecutive-failure counter."""
        from unittest import mock

        import ra_aid.fallback_handler as fallback_module

        class DummyModel:
            def bind_tools(self, tools, tool_choice):
                pass

        def dummy_initialize_llm(provider, model_name, temperature=None):
            return DummyModel()

        def dummy_merge_chat_history():
            return ["merged"]

        def dummy_validate_provider_env(provider):
            return True

        # fallback_handler binds these helpers with ``from ra_aid.llm import``,
        # so they must be replaced on the fallback_handler module itself;
        # patching ra_aid.llm would leave the handler calling the real ones.
        # patch.multiple also restores the originals if an assertion fails.
        with mock.patch.multiple(
            fallback_module,
            initialize_llm=dummy_initialize_llm,
            merge_chat_history=dummy_merge_chat_history,
            validate_provider_env=dummy_validate_provider_env,
        ):
            self.fallback_handler.tool_failure_consecutive_failures = 2
            self.fallback_handler.attempt_fallback(
                "dummy_tool_call()", self.logger, self.agent
            )

        self.assertEqual(self.fallback_handler.tool_failure_consecutive_failures, 0)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()
|
||||||
Loading…
Reference in New Issue