diff --git a/ra_aid/agent_utils.py b/ra_aid/agent_utils.py
index cdeac25..4dcc6a5 100644
--- a/ra_aid/agent_utils.py
+++ b/ra_aid/agent_utils.py
@@ -806,6 +806,7 @@ def run_agent_with_retry(agent, prompt: str, config: dict) -> Optional[str]:
                 logger.debug("Agent output: %s", chunk)
                 check_interrupt()
                 print_agent_output(chunk)
+
                 if _global_memory["plan_completed"]:
                     _global_memory["plan_completed"] = False
                     _global_memory["task_completed"] = False
diff --git a/ra_aid/agents/ciayn_agent.py b/ra_aid/agents/ciayn_agent.py
index a651ab8..57c7467 100644
--- a/ra_aid/agents/ciayn_agent.py
+++ b/ra_aid/agents/ciayn_agent.py
@@ -4,7 +4,7 @@ from typing import Any, Dict, Generator, List, Optional, Union
 
 from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
 
-from ra_aid.config import DEFAULT_MAX_TOOL_FAILURES
+from ra_aid.fallback_handler import FallbackHandler
 from ra_aid.exceptions import ToolExecutionError
 from ra_aid.logging_config import get_logger
 from ra_aid.models_params import DEFAULT_TOKEN_LIMIT
@@ -90,8 +90,7 @@ class CiaynAgent:
             config = {}
         self.config = config
         self.provider = config.get("provider", "openai")
-        self.fallback_enabled = config.get("fallback_tool_enabled", True)
-        self.fallback_tool_models = self._load_fallback_tool_models(config)
+        self.fallback_handler = FallbackHandler(config)
         self.model = model
         self.tools = tools
 
@@ -100,18 +99,8 @@ class CiaynAgent:
         self.available_functions = []
         for t in tools:
             self.available_functions.append(get_function_info(t.func))
-        self.tool_failure_consecutive_failures = 0
         self.tool_failure_current_provider = None
         self.tool_failure_current_model = None
-        self.tool_failure_used_fallbacks = set()
-
-    def _load_fallback_tool_models(self, config: dict) -> list:
-        fallback_tool_models_config = config.get("fallback_tool_models")
-        if fallback_tool_models_config:
-            return [m.strip() for m in fallback_tool_models_config.split(",") if m.strip()]
-        else:
-            from ra_aid.tool_leaderboard import supported_top_tool_models
-            return [item["model"] for item in supported_top_tool_models[:5]]
 
     def _build_prompt(self, last_result: Optional[str] = None) -> str:
         """Build the prompt for the agent including available tools and context."""
@@ -262,7 +251,7 @@ Output **ONLY THE CODE** and **NO MARKDOWN BACKTICKS**"""
             logger.debug(
                 f"_execute_tool: tool executed successfully with result: {result}"
             )
-            self.tool_failure_consecutive_failures = 0
+            self.fallback_handler.reset_fallback_handler()
             return result
         except Exception as e:
             logger.debug(f"_execute_tool: exception caught: {e}")
@@ -275,61 +264,7 @@ Output **ONLY THE CODE** and **NO MARKDOWN BACKTICKS**"""
             )
 
     def _handle_tool_failure(self, code: str, error: Exception) -> None:
-        logger.debug(
-            f"_handle_tool_failure: tool failure encountered for code '{code}' with error: {error}"
-        )
-        self.tool_failure_consecutive_failures += 1
-        max_failures = self.config.get("max_tool_failures", DEFAULT_MAX_TOOL_FAILURES)
-        logger.debug(
-            f"_handle_tool_failure: failure count {self.tool_failure_consecutive_failures}, max_failures {max_failures}"
-        )
-        if (
-            self.fallback_enabled
-            and self.tool_failure_consecutive_failures >= max_failures
-            and self.fallback_tool_models
-        ):
-            logger.debug(
-                "_handle_tool_failure: threshold reached, invoking fallback mechanism."
-            )
-            self._attempt_fallback(code)
-
-    def _attempt_fallback(self, code: str) -> None:
-        logger.debug(f"_attempt_fallback: initiating fallback for code: {code}")
-        new_model = self.fallback_tool_models[0]
-        failed_tool_call_name = code.split("(")[0].strip()
-        logger.error(
-            f"Tool call failed {self.tool_failure_consecutive_failures} times. Attempting fallback to model: {new_model} for tool: {failed_tool_call_name}"
-        )
-        try:
-            from ra_aid.llm import (
-                initialize_llm,
-                merge_chat_history,
-                validate_provider_env,
-            )
-
-            logger.debug(f"_attempt_fallback: validating provider {self.provider}")
-            if not validate_provider_env(self.provider):
-                logger.error(
-                    f"Missing environment configuration for provider {self.provider}. Cannot fallback."
-                )
-            else:
-                logger.debug(
-                    f"_attempt_fallback: initializing fallback model {new_model}"
-                )
-                self.model = initialize_llm(self.provider, new_model)
-                logger.debug(
-                    f"_attempt_fallback: binding tools to new model using tool: {failed_tool_call_name}"
-                )
-                self.model.bind_tools(self.tools, tool_choice=failed_tool_call_name)
-                self.tool_failure_used_fallbacks.add(new_model)
-                logger.debug("_attempt_fallback: merging chat history for fallback")
-                merge_chat_history()
-                self.tool_failure_consecutive_failures = 0
-                logger.debug(
-                    "_attempt_fallback: fallback successful and tool failure counter reset"
-                )
-        except Exception as switch_e:
-            logger.error(f"Fallback model switching failed: {switch_e}")
+        self.fallback_handler.handle_failure(code, error, logger, self)
 
     def _create_agent_chunk(self, content: str) -> Dict[str, Any]:
         """Create an agent chunk in the format expected by print_agent_output."""
diff --git a/ra_aid/config.py b/ra_aid/config.py
index 41868dd..4c9bfea 100644
--- a/ra_aid/config.py
+++ b/ra_aid/config.py
@@ -3,6 +3,7 @@
 DEFAULT_RECURSION_LIMIT = 100
 DEFAULT_MAX_TEST_CMD_RETRIES = 3
 DEFAULT_MAX_TOOL_FAILURES = 3
+FALLBACK_TOOL_MODEL_LIMIT = 5
 
 VALID_PROVIDERS = [
     "anthropic",
diff --git a/ra_aid/fallback_handler.py b/ra_aid/fallback_handler.py
new file mode 100644
index 0000000..0eef7b6
--- /dev/null
+++ b/ra_aid/fallback_handler.py
@@ -0,0 +1,99 @@
+from ra_aid.config import DEFAULT_MAX_TOOL_FAILURES, FALLBACK_TOOL_MODEL_LIMIT
+from ra_aid.tool_leaderboard import supported_top_tool_models
+from rich.console import Console
+from rich.markdown import Markdown
+from rich.panel import Panel
+from ra_aid.llm import initialize_llm, merge_chat_history, validate_provider_env
+
+
+class FallbackHandler:
+    def __init__(self, config):
+        self.config = config
+        self.fallback_enabled = config.get("fallback_tool_enabled", True)
+        self.fallback_tool_models = self._load_fallback_tool_models(config)
+        self.tool_failure_consecutive_failures = 0
+        self.tool_failure_used_fallbacks = set()
+
+    def _load_fallback_tool_models(self, config):
+        fallback_tool_models_config = config.get("fallback_tool_models")
+        if fallback_tool_models_config:
+            return [
+                m.strip() for m in fallback_tool_models_config.split(",") if m.strip()
+            ]
+        else:
+            console = Console()
+            supported = []
+            skipped = []
+            for item in supported_top_tool_models:
+                provider = item.get("provider")
+                model_name = item.get("model")
+                if validate_provider_env(provider):
+                    supported.append(model_name)
+                    if len(supported) == FALLBACK_TOOL_MODEL_LIMIT:
+                        break
+                else:
+                    skipped.append(model_name)
+            final_models = supported[:FALLBACK_TOOL_MODEL_LIMIT]
+            message = "Fallback models selected: " + ", ".join(final_models)
+            if skipped:
+                message += (
+                    "\nSkipped top tool calling models due to missing provider ENV API keys: "
+                    + ", ".join(skipped)
+                )
+            console.print(Panel(Markdown(message), title="Fallback Models"))
+            return final_models
+
+    def handle_failure(self, code: str, error: Exception, logger, agent):
+        logger.debug(
+            f"_handle_tool_failure: tool failure encountered for code '{code}' with error: {error}"
+        )
+        self.tool_failure_consecutive_failures += 1
+        max_failures = self.config.get("max_tool_failures", DEFAULT_MAX_TOOL_FAILURES)
+        logger.debug(
+            f"_handle_tool_failure: failure count {self.tool_failure_consecutive_failures}, max_failures {max_failures}"
+        )
+        if (
+            self.fallback_enabled
+            and self.tool_failure_consecutive_failures >= max_failures
+            and self.fallback_tool_models
+        ):
+            logger.debug(
+                "_handle_tool_failure: threshold reached, invoking fallback mechanism."
+            )
+            self.attempt_fallback(code, logger, agent)
+
+    def attempt_fallback(self, code: str, logger, agent):
+        logger.debug(f"_attempt_fallback: initiating fallback for code: {code}")
+        new_model = self.fallback_tool_models[0]
+        failed_tool_call_name = code.split("(")[0].strip()
+        logger.error(
+            f"Tool call failed {self.tool_failure_consecutive_failures} times. Attempting fallback to model: {new_model} for tool: {failed_tool_call_name}"
+        )
+        try:
+            logger.debug(f"_attempt_fallback: validating provider {agent.provider}")
+            if not validate_provider_env(agent.provider):
+                logger.error(
+                    f"Missing environment configuration for provider {agent.provider}. Cannot fallback."
+                )
+            else:
+                logger.debug(
+                    f"_attempt_fallback: initializing fallback model {new_model}"
+                )
+                agent.model = initialize_llm(agent.provider, new_model)
+                logger.debug(
+                    f"_attempt_fallback: binding tools to new model using tool: {failed_tool_call_name}"
+                )
+                agent.model.bind_tools(agent.tools, tool_choice=failed_tool_call_name)
+                self.tool_failure_used_fallbacks.add(new_model)
+                logger.debug("_attempt_fallback: merging chat history for fallback")
+                merge_chat_history()
+                self.tool_failure_consecutive_failures = 0
+                logger.debug(
+                    "_attempt_fallback: fallback successful and tool failure counter reset"
+                )
+        except Exception as switch_e:
+            logger.error(f"Fallback model switching failed: {switch_e}")
+
+    def reset_fallback_handler(self):
+        self.tool_failure_consecutive_failures = 0
+        self.tool_failure_used_fallbacks.clear()
diff --git a/tests/ra_aid/test_ciayn_agent.py b/tests/ra_aid/test_ciayn_agent.py
index 2e39dfa..46c8191 100644
--- a/tests/ra_aid/test_ciayn_agent.py
+++ b/tests/ra_aid/test_ciayn_agent.py
@@ -264,61 +264,7 @@ class TestFunctionCallValidation:
 
 
 class TestCiaynAgentNewMethods(unittest.TestCase):
-    def setUp(self):
-        # Create a dummy tool that always fails for testing fallback
-        def always_fail():
-            raise Exception("Failure for fallback test")
-
-        self.always_fail_tool = DummyTool(always_fail)
-        # Create a dummy model that does minimal work for fallback tests
-        self.dummy_model = DummyModel()
-        # Initialize CiaynAgent with configuration to trigger fallback quickly
-        self.agent = CiaynAgent(
-            self.dummy_model,
-            [self.always_fail_tool],
-            config={
-                "max_tool_failures": 2,
-                "fallback_tool_models": "dummy-fallback-model",
-            },
-        )
-
-    def test_handle_tool_failure_increments_counter(self):
-        initial_failures = self.agent.tool_failure_consecutive_failures
-        self.agent._handle_tool_failure("dummy_call()", Exception("Test error"))
-        self.assertEqual(
-            self.agent.tool_failure_consecutive_failures, initial_failures + 1
-        )
-
-    def test_attempt_fallback_invokes_fallback_logic(self):
-        # Monkey-patch initialize_llm, merge_chat_history, and validate_provider_env
-        # to simulate fallback switching without external dependencies.
-        def dummy_initialize_llm(provider, model_name, temperature=None):
-            return self.dummy_model
-
-        def dummy_merge_chat_history():
-            return ["merged"]
-
-        def dummy_validate_provider_env(provider):
-            return True
-
-        import ra_aid.llm as llm
-
-        original_initialize = llm.initialize_llm
-        original_merge = llm.merge_chat_history
-        original_validate = llm.validate_provider_env
-        llm.initialize_llm = dummy_initialize_llm
-        llm.merge_chat_history = dummy_merge_chat_history
-        llm.validate_provider_env = dummy_validate_provider_env
-
-        # Set failure counter high enough to trigger fallback in _handle_tool_failure
-        self.agent.tool_failure_consecutive_failures = 2
-        # Call _attempt_fallback; it should reset the failure counter to 0 on success.
-        self.agent._attempt_fallback("always_fail_tool()")
-        self.assertEqual(self.agent.tool_failure_consecutive_failures, 0)
-        # Restore original functions
-        llm.initialize_llm = original_initialize
-        llm.merge_chat_history = original_merge
-        llm.validate_provider_env = original_validate
+    pass
 
 
 if __name__ == "__main__":
diff --git a/tests/ra_aid/test_fallback_handler.py b/tests/ra_aid/test_fallback_handler.py
new file mode 100644
index 0000000..6f0c285
--- /dev/null
+++ b/tests/ra_aid/test_fallback_handler.py
@@ -0,0 +1,58 @@
+import unittest
+from ra_aid.fallback_handler import FallbackHandler
+
+class DummyLogger:
+    def debug(self, msg):
+        pass
+    def error(self, msg):
+        pass
+
+class DummyAgent:
+    provider = "openai"
+    tools = []
+    model = None
+
+class TestFallbackHandler(unittest.TestCase):
+    def setUp(self):
+        self.config = {"max_tool_failures": 2, "fallback_tool_models": "dummy-fallback-model"}
+        self.fallback_handler = FallbackHandler(self.config)
+        self.logger = DummyLogger()
+        self.agent = DummyAgent()
+
+    def test_handle_failure_increments_counter(self):
+        initial_failures = self.fallback_handler.tool_failure_consecutive_failures
+        self.fallback_handler.handle_failure("dummy_call()", Exception("Test error"), self.logger, self.agent)
+        self.assertEqual(self.fallback_handler.tool_failure_consecutive_failures, initial_failures + 1)
+
+    def test_attempt_fallback_resets_counter(self):
+        # Monkey-patch dummy functions for fallback components
+        def dummy_initialize_llm(provider, model_name, temperature=None):
+            class DummyModel:
+                def bind_tools(self, tools, tool_choice):
+                    pass
+            return DummyModel()
+
+        def dummy_merge_chat_history():
+            return ["merged"]
+
+        def dummy_validate_provider_env(provider):
+            return True
+
+        import ra_aid.llm as llm
+        original_initialize = llm.initialize_llm
+        original_merge = llm.merge_chat_history
+        original_validate = llm.validate_provider_env
+        llm.initialize_llm = dummy_initialize_llm
+        llm.merge_chat_history = dummy_merge_chat_history
+        llm.validate_provider_env = dummy_validate_provider_env
+
+        self.fallback_handler.tool_failure_consecutive_failures = 2
+        self.fallback_handler.attempt_fallback("dummy_tool_call()", self.logger, self.agent)
+        self.assertEqual(self.fallback_handler.tool_failure_consecutive_failures, 0)
+
+        llm.initialize_llm = original_initialize
+        llm.merge_chat_history = original_merge
+        llm.validate_provider_env = original_validate
+
+if __name__ == "__main__":
+    unittest.main()
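
Usage sketch (illustrative, not part of the patch): the snippet below drives the new FallbackHandler the way CiaynAgent does after this change. The StubLogger and StubAgent stand-ins are assumptions that mirror the DummyLogger/DummyAgent fixtures in tests/ra_aid/test_fallback_handler.py; an actual model switch additionally requires provider API keys so that validate_provider_env() passes inside attempt_fallback().

# Hypothetical driver for FallbackHandler; the stub objects are assumptions.
from ra_aid.fallback_handler import FallbackHandler

class StubLogger:
    def debug(self, msg):
        pass
    def error(self, msg):
        pass

class StubAgent:
    provider = "openai"  # provider checked by attempt_fallback() via validate_provider_env()
    tools = []
    model = None

config = {"max_tool_failures": 2, "fallback_tool_models": "dummy-fallback-model"}
handler = FallbackHandler(config)
logger, agent = StubLogger(), StubAgent()

# Each failed tool call is reported; once the consecutive-failure count reaches
# max_tool_failures, handle_failure() calls attempt_fallback() to swap agent.model.
handler.handle_failure("dummy_call()", Exception("boom"), logger, agent)
print(handler.tool_failure_consecutive_failures)  # 1

# A successful tool call resets the failure state, as _execute_tool() now does.
handler.reset_fallback_handler()
print(handler.tool_failure_consecutive_failures)  # 0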