refactor(agent_utils.py): remove the _handle_tool_execution_error function and simplify error handling in run_agent_with_retry

feat(fallback_handler.py): enhance handle_failure method to extract tool name from ToolExecutionError and improve fallback logic
fix(exceptions.py): update ToolExecutionError to include base_message for better error context
feat(output.py): add base_message to ToolExecutionError for improved debugging
chore(tool_configs.py): update get_all_tools function to specify return type
style(logging_config.py): reorder imports for consistency
test(tests): add tests for new error handling and fallback logic in agent_utils and fallback_handler
This commit is contained in:
Ariel Frischer 2025-02-12 13:07:12 -08:00
parent a7322eaef2
commit af9f95ceb1
8 changed files with 252 additions and 208 deletions
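
Taken together, the agent_utils.py change below reduces the retry loop's error handling to a single call into the fallback handler. A minimal sketch of the resulting flow (the surrounding loop and the _run_agent_stream call are assumptions based on the diff and the tests further down, not the verbatim function):

from ra_aid.agent_utils import _run_agent_stream
from ra_aid.exceptions import ToolExecutionError


def retry_loop_sketch(agent, fallback_handler, original_prompt, config):
    # Hypothetical wrapper, for illustration only.
    prompt = original_prompt
    while True:
        try:
            _run_agent_stream(agent, prompt, config)
            return "Agent run completed successfully"
        except ToolExecutionError as e:
            # handle_failure() now pulls the failing tool name out of the
            # exception itself, replacing _handle_tool_execution_error.
            fallback_response = fallback_handler.handle_failure(e, agent)
            if fallback_response:
                prompt = original_prompt + "\n" + fallback_response
                continue
            raise  # behaviour when no fallback is produced is assumed here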

View File

@@ -9,7 +9,6 @@ import uuid
 from datetime import datetime
 from typing import Any, Dict, List, Literal, Optional, Sequence
-from langgraph.graph.graph import CompiledGraph
 import litellm
 from anthropic import APIError, APITimeoutError, InternalServerError, RateLimitError
 from langchain_core.language_models import BaseChatModel
@@ -20,6 +19,7 @@ from langchain_core.messages import (
 )
 from langchain_core.tools import tool
 from langgraph.checkpoint.memory import MemorySaver
+from langgraph.graph.graph import CompiledGraph
 from langgraph.prebuilt import create_react_agent
 from langgraph.prebuilt.chat_agent_executor import AgentState
 from litellm import get_model_info
@@ -876,9 +876,7 @@ def run_agent_with_retry(
             logger.debug("Agent run completed successfully")
             return "Agent run completed successfully"
         except ToolExecutionError as e:
-            fallback_response = _handle_tool_execution_error(
-                fallback_handler, agent, e
-            )
+            fallback_response = fallback_handler.handle_failure(e, agent)
             if fallback_response:
                 prompt = original_prompt + "\n" + fallback_response
                 continue
@@ -895,42 +893,3 @@ def run_agent_with_retry(
     finally:
         _decrement_agent_depth()
         _restore_interrupt_handling(original_handler)
-
-
-def _handle_tool_execution_error(
-    fallback_handler: FallbackHandler,
-    agent: CiaynAgent | CompiledGraph,
-    error: ToolExecutionError,
-):
-    logger.debug("Entering _handle_tool_execution_error with error: %s", error)
-    if error.tool_name:
-        failed_tool_call_name = error.tool_name
-        logger.debug(
-            "Extracted failed_tool_call_name from error.tool_name: %s",
-            failed_tool_call_name,
-        )
-    else:
-        import re
-
-        msg = str(error)
-        logger.debug("Error message: %s", msg)
-        match = re.search(r"name=['\"](\w+)['\"]", msg)
-        if match:
-            failed_tool_call_name = match.group(1)
-            logger.debug(
-                "Extracted failed_tool_call_name using regex: %s", failed_tool_call_name
-            )
-        else:
-            failed_tool_call_name = "Tool execution error"
-            logger.debug(
-                "Defaulting failed_tool_call_name to: %s", failed_tool_call_name
-            )
-    logger.debug(
-        "Calling fallback_handler.handle_failure with failed_tool_call_name: %s",
-        failed_tool_call_name,
-    )
-    fallback_response = fallback_handler.handle_failure(
-        failed_tool_call_name, error, agent
-    )
-    logger.debug("Fallback response received: %s", fallback_response)
-    return fallback_response

View File

@@ -44,7 +44,9 @@ def print_agent_output(chunk: Dict[str, Any]) -> None:
                 )
             )
             tool_name = getattr(msg, "name", None)
-            raise ToolExecutionError(err_msg, tool_name=tool_name)
+            cpm(f"type(msg): {type(msg)}")
+            cpm(f"msg: {msg}")
+            raise ToolExecutionError(err_msg, tool_name=tool_name, base_message=msg)


 def cpm(message: str, title: Optional[str] = None, border_style: str = "blue") -> None:

View File

@@ -1,5 +1,9 @@
 """Custom exceptions for RA.Aid."""

+from typing import Optional
+
+from langchain_core.messages import BaseMessage
+

 class AgentInterrupt(Exception):
     """Exception raised when an agent's execution is interrupted.
@@ -17,6 +21,13 @@ class ToolExecutionError(Exception):
     This exception is used to distinguish tool execution failures
     from other types of errors in the agent system.
     """

-    def __init__(self, message, tool_name=None):
+    def __init__(
+        self,
+        message: str,
+        base_message: Optional[BaseMessage] = None,
+        tool_name: Optional[str] = None,
+    ):
         super().__init__(message)
+        self.base_message = base_message
         self.tool_name = tool_name
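
With the exceptions.py change above, the error object can carry both the failing tool's name and the original message that produced it. A small hedged example (the AIMessage content and the read_file tool name are made up for illustration):

from langchain_core.messages import AIMessage

from ra_aid.exceptions import ToolExecutionError

msg = AIMessage(content="read_file failed: no such file")  # stand-in for the real agent message
try:
    raise ToolExecutionError(
        "Tool execution failed", base_message=msg, tool_name="read_file"
    )
except ToolExecutionError as e:
    print(e.tool_name)     # "read_file"
    print(e.base_message)  # the originating BaseMessage, handy when debugging fallbacks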

View File

@@ -1,20 +1,22 @@
+import json
+import re
+
 from langchain_core.tools import BaseTool
 from langgraph.graph.graph import CompiledGraph
 from langgraph.graph.message import BaseMessage
-from ra_aid.console.output import cpm
-import json
 from ra_aid.agents.ciayn_agent import CiaynAgent
 from ra_aid.config import (
     DEFAULT_MAX_TOOL_FAILURES,
     FALLBACK_TOOL_MODEL_LIMIT,
     RETRY_FALLBACK_COUNT,
 )
-from ra_aid.logging_config import get_logger
-from ra_aid.tool_leaderboard import supported_top_tool_models
-from rich.console import Console
+from ra_aid.console.output import cpm
+from ra_aid.exceptions import ToolExecutionError
 from ra_aid.llm import initialize_llm, validate_provider_env
+from ra_aid.logging_config import get_logger
+from ra_aid.tool_configs import get_all_tools
+from ra_aid.tool_leaderboard import supported_top_tool_models

 logger = get_logger(__name__)
@@ -41,9 +43,12 @@ class FallbackHandler:
         self.tools: list[BaseTool] = tools
         self.fallback_enabled = config.get("fallback_tool_enabled", True)
         self.fallback_tool_models = self._load_fallback_tool_models(config)
+        self.max_failures = config.get("max_tool_failures", DEFAULT_MAX_TOOL_FAILURES)
         self.tool_failure_consecutive_failures = 0
+        self.failed_messages = set()
         self.tool_failure_used_fallbacks = set()
-        self.console = Console()
+        self.current_failing_tool_name = ""
+        self.current_tool_to_bind = None

     def _load_fallback_tool_models(self, config):
         """
@@ -87,66 +92,104 @@ class FallbackHandler:
         cpm(message, title="Fallback Models")
         return final_models

-    def handle_failure(
-        self, code: str, error: Exception, agent: CiaynAgent | CompiledGraph
-    ):
-        """
-        Handle a tool failure by incrementing the failure counter and triggering fallback if thresholds are exceeded.
-
-        Args:
-            code (str): The code that failed to execute.
-            error (Exception): The exception raised during execution.
-            logger: Logger instance for logging.
-            agent: The agent instance on which fallback may be executed.
-        """
-        logger.debug(
-            f"_handle_tool_failure: tool failure encountered for code '{code}' with error: {error}"
-        )
-        self.tool_failure_consecutive_failures += 1
-        max_failures = self.config.get("max_tool_failures", DEFAULT_MAX_TOOL_FAILURES)
-        logger.debug(
-            f"_handle_tool_failure: failure count {self.tool_failure_consecutive_failures}, max_failures {max_failures}"
-        )
-        if (
-            self.fallback_enabled
-            and self.tool_failure_consecutive_failures >= max_failures
-            and self.fallback_tool_models
-        ):
-            logger.debug(
-                "_handle_tool_failure: threshold reached, invoking fallback mechanism."
-            )
-            return self.attempt_fallback(code, logger, agent)
+    def handle_failure(self, error: Exception, agent: CiaynAgent | CompiledGraph):
+        """
+        Handle a tool failure by incrementing the failure counter and triggering fallback if thresholds are exceeded.
+
+        Args:
+            error (Exception): The exception raised during execution. If the exception has a 'base_message' attribute, that message is recorded.
+            agent: The agent instance on which fallback may be executed.
+        """
+        if not self.fallback_enabled:
+            return None
+
+        failed_tool_call_name = self.extract_failed_tool_name(error)
+
+        if (
+            self.current_failing_tool_name
+            and failed_tool_call_name != self.current_failing_tool_name
+        ):
+            logger.debug(
+                "New failing tool name identified. Resetting consecutive tool failures."
+            )
+            self.reset_fallback_handler()
+
+        logger.debug(
+            f"_handle_tool_failure: tool failure encountered for code '{failed_tool_call_name}' with error: {error}"
+        )
+        self.current_failing_tool_name = failed_tool_call_name
+        self.current_tool_to_bind = self._find_tool_to_bind(
+            agent, failed_tool_call_name
+        )
+
+        if hasattr(error, "base_message") and error.base_message:
+            self.failed_messages.add(str(error.base_message))
+
+        self.tool_failure_consecutive_failures += 1
+        logger.debug(
+            f"_handle_tool_failure: failure count {self.tool_failure_consecutive_failures}, max_failures {self.max_failures}"
+        )
+        if (
+            self.tool_failure_consecutive_failures >= self.max_failures
+            and self.fallback_tool_models
+        ):
+            logger.debug(
+                "_handle_tool_failure: threshold reached, invoking fallback mechanism."
+            )
+            return self.attempt_fallback()

-    def attempt_fallback(self, code: str, logger, agent):
-        """
-        Initiate the fallback process by selecting a fallback model and triggering the appropriate fallback method.
-
-        Args:
-            code (str): The tool code that triggered the fallback.
-            logger: Logger instance for logging messages.
-            agent: The agent for which fallback is being executed.
-        """
-        logger.debug(f"_attempt_fallback: initiating fallback for code: {code}")
-        fallback_model = self.fallback_tool_models[0]
-        failed_tool_call_name = code
-        logger.error(
-            f"Tool call failed {self.tool_failure_consecutive_failures} times. Attempting fallback to model: {fallback_model['model']} for tool: {failed_tool_call_name}"
-        )
-        cpm(
-            f"**Tool fallback activated**: Switching to fallback model {fallback_model['model']} for tool {failed_tool_call_name}.",
-            title="Fallback Notification",
-        )
-        if fallback_model.get("type", "prompt").lower() == "fc":
-            self.attempt_fallback_function(code, logger, agent)
-        else:
-            self.attempt_fallback_prompt(code, logger, agent)
+    def attempt_fallback(self):
+        """
+        Initiate the fallback process by iterating over all fallback models and triggering the appropriate fallback method.
+
+        Returns:
+            The response from a fallback model if any, otherwise None.
+        """
+        logger.error(
+            f"Tool call failed {self.tool_failure_consecutive_failures} times. Attempting fallback for tool: {self.current_failing_tool_name}"
+        )
+        cpm(
+            f"**Tool fallback activated**: Attempting fallback for tool {self.current_failing_tool_name}.",
+            title="Fallback Notification",
+        )
+        for fallback_model in self.fallback_tool_models:
+            if fallback_model.get("type", "prompt").lower() == "fc":
+                response = self.attempt_fallback_function(fallback_model)
+            else:
+                response = self.attempt_fallback_prompt(fallback_model)
+            if response:
+                return response
+        cpm("All fallback models have failed", title="Fallback Failed")
+        return None

     def reset_fallback_handler(self):
         """
         Reset the fallback handler's internal failure counters and clear the record of used fallback models.
         """
         self.tool_failure_consecutive_failures = 0
+        self.failed_messages.clear()
         self.tool_failure_used_fallbacks.clear()
+        self.fallback_tool_models = self._load_fallback_tool_models(self.config)
+
+    def extract_failed_tool_name(self, error: ToolExecutionError):
+        if error.tool_name:
+            failed_tool_call_name = error.tool_name
+        else:
+            msg = str(error)
+            logger.debug("Error message: %s", msg)
+            match = re.search(r"name=['\"](\w+)['\"]", msg)
+            if match:
+                failed_tool_call_name = str(match.group(1))
+                logger.debug(
+                    "Extracted failed_tool_call_name using regex: %s",
+                    failed_tool_call_name,
+                )
+            else:
+                failed_tool_call_name = "Tool execution error"
+                raise Exception("Fallback failed: Could not extract failed tool name.")
+        return failed_tool_call_name

     def _find_tool_to_bind(self, agent, failed_tool_call_name):
         logger.debug(f"failed_tool_call_name={failed_tool_call_name}")
@@ -157,135 +200,108 @@ class FallbackHandler:
             None,
         )
         if tool_to_bind is None:
-            from ra_aid.tool_configs import get_all_tools
-
             all_tools = get_all_tools()
             tool_to_bind = next(
                 (t for t in all_tools if t.func.__name__ == failed_tool_call_name),
                 None,
             )
         if tool_to_bind is None:
-            available = [t.func.__name__ for t in get_all_tools()]
-            logger.debug(
-                f"Failed to find tool: {failed_tool_call_name}. Available tools: {available}"
-            )
-            raise Exception(f"Tool {failed_tool_call_name} not found in all tools.")
+            # TODO: Would be nice to try fuzzy match or levenstein str match to find closest correspond tool name
+            raise Exception(
+                f"Fallback failed: {failed_tool_call_name} not found in all tools."
+            )
         return tool_to_bind

-    def attempt_fallback_prompt(self, code: str, logger, agent):
-        """
-        Attempt a prompt-based fallback by iterating over fallback models and invoking the provided code.
-
-        This method tries each fallback model (with retry logic configured) until one successfully executes the code.
-
-        Args:
-            code (str): The tool code to invoke via fallback.
-            logger: Logger instance for logging messages.
-            agent: The agent instance to update with the new model upon success.
-
-        Returns:
-            The response from the fallback model invocation.
-
-        Raises:
-            Exception: If all prompt-based fallback models fail.
-        """
-        logger.debug("Attempting prompt-based fallback using fallback models")
-        failed_tool_call_name = code
-        for fallback_model in self.fallback_tool_models:
-            try:
-                logger.debug(f"Trying fallback model: {fallback_model['model']}")
-                simple_model = initialize_llm(
-                    fallback_model["provider"], fallback_model["model"]
-                )
-                tool_to_bind = self._find_tool_to_bind(agent, failed_tool_call_name)
-                binded_model = simple_model.bind_tools(
-                    [tool_to_bind], tool_choice=failed_tool_call_name
-                )
-                # retry_model = binded_model.with_retry(
-                #     stop_after_attempt=RETRY_FALLBACK_COUNT
-                # )
-                response = binded_model.invoke(code)
-                cpm(f"response={response}")
-
-                self.tool_failure_used_fallbacks.add(fallback_model["model"])
-                tool_call = self.base_message_to_tool_call_dict(response)
-                if tool_call:
-                    result = self.invoke_prompt_tool_call(tool_call)
-                    cpm(f"result={result}")
-                    logger.debug(
-                        "Prompt-based fallback executed successfully with model: "
-                        + fallback_model["model"]
-                    )
-                    self.reset_fallback_handler()
-                    return result
-                else:
-                    cpm(
-                        response.content if hasattr(response, "content") else response,
-                        title="Fallback Model Response: " + fallback_model["model"],
-                    )
-                    return response
-            except Exception as e:
-                if isinstance(e, KeyboardInterrupt):
-                    raise
-                logger.error(
-                    f"Prompt-based fallback with model {fallback_model['model']} failed: {e}"
-                )
-        raise Exception("All prompt-based fallback models failed")
+    def attempt_fallback_prompt(self, fallback_model):
+        """
+        Attempt a prompt-based fallback by invoking the current failing tool with the given fallback model.
+
+        Args:
+            fallback_model (dict): The fallback model to use.
+
+        Returns:
+            The response from the fallback model invocation, or None if failed.
+        """
+        try:
+            logger.debug(f"Trying fallback model: {fallback_model['model']}")
+            simple_model = initialize_llm(
+                fallback_model["provider"], fallback_model["model"]
+            )
+            binded_model = simple_model.bind_tools(
+                [self.current_tool_to_bind],
+                tool_choice=self.current_failing_tool_name,
+            )
+            retry_model = binded_model.with_retry(
+                stop_after_attempt=RETRY_FALLBACK_COUNT
+            )
+            response = retry_model.invoke(self.current_failing_tool_name)
+            cpm(f"response={response}")
+            self.tool_failure_used_fallbacks.add(fallback_model["model"])
+            tool_call = self.base_message_to_tool_call_dict(response)
+            if tool_call:
+                result = self.invoke_prompt_tool_call(tool_call)
+                cpm(f"result={result}")
+                logger.debug(
+                    "Prompt-based fallback executed successfully with model: "
+                    + fallback_model["model"]
+                )
+                self.reset_fallback_handler()
+                return result
+            else:
+                cpm(
+                    response.content if hasattr(response, "content") else response,
+                    title="Fallback Model Response: " + fallback_model["model"],
+                )
+                return response
+        except Exception as e:
+            if isinstance(e, KeyboardInterrupt):
+                raise
+            logger.error(
+                f"Prompt-based fallback with model {fallback_model['model']} failed: {e}"
+            )
+            return None

-    def attempt_fallback_function(self, code: str, logger, agent):
-        """
-        Attempt a function-calling fallback by iterating over fallback models and invoking the provided code.
-
-        This method tries each fallback model (with retry logic configured) until one successfully executes the code.
-
-        Args:
-            code (str): The tool code to invoke via fallback.
-            logger: Logger instance for logging messages.
-            agent: The agent instance to update with the new model upon success.
-
-        Returns:
-            The response from the fallback model invocation.
-
-        Raises:
-            Exception: If all function-calling fallback models fail.
-        """
-        logger.debug("Attempting function-calling fallback using fallback models")
-        failed_tool_call_name = code
-        for fallback_model in self.fallback_tool_models:
-            try:
-                logger.debug(f"Trying fallback model: {fallback_model['model']}")
-                simple_model = initialize_llm(
-                    fallback_model["provider"], fallback_model["model"]
-                )
-                tool_to_bind = self._find_tool_to_bind(agent, failed_tool_call_name)
-                binded_model = simple_model.bind_tools(
-                    [tool_to_bind], tool_choice=failed_tool_call_name
-                )
-                retry_model = binded_model.with_retry(
-                    stop_after_attempt=RETRY_FALLBACK_COUNT
-                )
-                response = retry_model.invoke(code)
-                cpm(f"response={response}")
-                self.tool_failure_used_fallbacks.add(fallback_model["model"])
-                self.reset_fallback_handler()
-                logger.debug(
-                    "Function-calling fallback executed successfully with model: "
-                    + fallback_model["model"]
-                )
-                cpm(
-                    response.content if hasattr(response, "content") else response,
-                    title="Fallback Model Response: " + fallback_model["model"],
-                )
-                return response
-            except Exception as e:
-                if isinstance(e, KeyboardInterrupt):
-                    raise
-                logger.error(
-                    f"Function-calling fallback with model {fallback_model['model']} failed: {e}"
-                )
-        raise Exception("All function-calling fallback models failed")
+    def attempt_fallback_function(self, fallback_model):
+        """
+        Attempt a function-calling fallback by invoking the current failing tool with the given fallback model.
+
+        Args:
+            fallback_model (dict): The fallback model to use.
+
+        Returns:
+            The response from the fallback model invocation, or None if failed.
+        """
+        try:
+            logger.debug(f"Trying fallback model: {fallback_model['model']}")
+            simple_model = initialize_llm(
+                fallback_model["provider"], fallback_model["model"]
+            )
+            binded_model = simple_model.bind_tools(
+                [self.current_tool_to_bind],
+                tool_choice=self.current_failing_tool_name,
+            )
+            retry_model = binded_model.with_retry(
+                stop_after_attempt=RETRY_FALLBACK_COUNT
+            )
+            response = retry_model.invoke(self.current_failing_tool_name)
+            cpm(f"response={response}")
+            self.tool_failure_used_fallbacks.add(fallback_model["model"])
+            logger.debug(
+                "Function-calling fallback executed successfully with model: "
+                + fallback_model["model"]
+            )
+            cpm(
                response.content if hasattr(response, "content") else response,
+                title="Fallback Model Response: " + fallback_model["model"],
+            )
+            return response
+        except Exception as e:
+            if isinstance(e, KeyboardInterrupt):
+                raise
+            logger.error(
+                f"Function-calling fallback with model {fallback_model['model']} failed: {e}"
+            )
+            return None

     def invoke_prompt_tool_call(self, tool_call_request: dict):
         """

View File

@@ -1,9 +1,10 @@
 import logging
 import sys
 from typing import Optional

 from rich.console import Console
-from rich.panel import Panel
 from rich.markdown import Markdown
+from rich.panel import Panel


 class PrettyHandler(logging.Handler):

View File

@@ -1,3 +1,5 @@
+from langchain_core.tools import BaseTool
+
 from ra_aid.tools import (
     ask_expert,
     ask_human,
@@ -61,11 +63,13 @@ def get_read_only_tools(
     return tools


 def get_all_tools_simple():
     """Return a list containing all available tools using existing group methods."""
     return get_all_tools()


-def get_all_tools():
+def get_all_tools() -> list[BaseTool]:
     """Return a list containing all available tools from different groups."""
     all_tools = []
     all_tools.extend(get_read_only_tools())
@@ -176,7 +180,7 @@ def get_implementation_tools(
     return tools


-def get_web_research_tools(expert_enabled: bool = True) -> list:
+def get_web_research_tools(expert_enabled: bool = True):
     """Get the list of tools available for web research.

     Args:
@@ -196,9 +200,7 @@ def get_web_research_tools(expert_enabled: bool = True) -> list:
     return tools


-def get_chat_tools(
-    expert_enabled: bool = True, web_research_enabled: bool = False
-) -> list:
+def get_chat_tools(expert_enabled: bool = True, web_research_enabled: bool = False):
     """Get the list of tools available in chat mode.

     Chat mode includes research and implementation capabilities but excludes
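
With the new annotation, call sites of get_all_tools get a typed list back; a quick sketch (the name lookup mirrors what FallbackHandler._find_tool_to_bind does above):

from langchain_core.tools import BaseTool

from ra_aid.tool_configs import get_all_tools

tools: list[BaseTool] = get_all_tools()
tool_names = [t.func.__name__ for t in tools]  # same attribute the fallback handler matches on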

View File

@@ -276,11 +276,19 @@ def test_get_model_token_limit_planner(mock_memory):
     token_limit = get_model_token_limit(config, "planner")
     assert token_limit == 120000


 # New tests for private helper methods in agent_utils.py
 def test_setup_and_restore_interrupt_handling():
-    import signal, threading
-    from ra_aid.agent_utils import _setup_interrupt_handling, _restore_interrupt_handling, _request_interrupt
+    import signal
+
+    from ra_aid.agent_utils import (
+        _request_interrupt,
+        _restore_interrupt_handling,
+        _setup_interrupt_handling,
+    )

     original_handler = signal.getsignal(signal.SIGINT)
     handler = _setup_interrupt_handling()
     # Verify the SIGINT handler is set to _request_interrupt
@@ -289,66 +297,93 @@ def test_setup_and_restore_interrupt_handling():
     # Verify the SIGINT handler is restored to the original
     assert signal.getsignal(signal.SIGINT) == original_handler


 def test_increment_and_decrement_agent_depth():
-    from ra_aid.agent_utils import _increment_agent_depth, _decrement_agent_depth, _global_memory
+    from ra_aid.agent_utils import (
+        _decrement_agent_depth,
+        _global_memory,
+        _increment_agent_depth,
+    )

     _global_memory["agent_depth"] = 10
     _increment_agent_depth()
     assert _global_memory["agent_depth"] == 11
     _decrement_agent_depth()
     assert _global_memory["agent_depth"] == 10


 def test_run_agent_stream(monkeypatch):
-    from ra_aid.agent_utils import _run_agent_stream, _global_memory
+    from ra_aid.agent_utils import _global_memory, _run_agent_stream

     # Create a dummy agent that yields one chunk
     class DummyAgent:
         def stream(self, msg, cfg):
             yield {"content": "chunk1"}

     dummy_agent = DummyAgent()
     # Set flags so that _run_agent_stream will reset them
     _global_memory["plan_completed"] = True
     _global_memory["task_completed"] = True
     _global_memory["completion_message"] = "existing"

     call_flag = {"called": False}

     def fake_print_agent_output(chunk):
         call_flag["called"] = True

-    monkeypatch.setattr("ra_aid.agent_utils.print_agent_output", fake_print_agent_output)
+    monkeypatch.setattr(
+        "ra_aid.agent_utils.print_agent_output", fake_print_agent_output
+    )
     _run_agent_stream(dummy_agent, "dummy prompt", {})
     assert call_flag["called"]
     assert _global_memory["plan_completed"] is False
     assert _global_memory["task_completed"] is False
     assert _global_memory["completion_message"] == ""


 def test_execute_test_command_wrapper(monkeypatch):
     from ra_aid.agent_utils import _execute_test_command_wrapper

     # Patch execute_test_command to return a testable tuple
     def fake_execute(config, orig, tests, auto):
         return (True, "new prompt", auto, tests + 1)

     monkeypatch.setattr("ra_aid.agent_utils.execute_test_command", fake_execute)
     result = _execute_test_command_wrapper("orig", {}, 0, False)
     assert result == (True, "new prompt", False, 1)


 def test_handle_api_error_valueerror():
-    from ra_aid.agent_utils import _handle_api_error
     import pytest
+
+    from ra_aid.agent_utils import _handle_api_error

     # ValueError not containing "code" or "429" should be re-raised
     with pytest.raises(ValueError):
         _handle_api_error(ValueError("some error"), 0, 5, 1)


 def test_handle_api_error_max_retries():
-    from ra_aid.agent_utils import _handle_api_error
     import pytest
+
+    from ra_aid.agent_utils import _handle_api_error

     # When attempt reaches max retries, a RuntimeError should be raised
     with pytest.raises(RuntimeError):
         _handle_api_error(Exception("error code 429"), 4, 5, 1)


 def test_handle_api_error_retry(monkeypatch):
-    from ra_aid.agent_utils import _handle_api_error
     import time
+
+    from ra_aid.agent_utils import _handle_api_error

     # Patch time.monotonic and time.sleep to simulate immediate delay expiration
     fake_time = [0]

     def fake_monotonic():
         fake_time[0] += 0.5
         return fake_time[0]

     monkeypatch.setattr(time, "monotonic", fake_monotonic)
     monkeypatch.setattr(time, "sleep", lambda s: None)
     # Should not raise error when attempt is lower than max retries

View File

@@ -1,28 +1,41 @@
 import unittest

 from ra_aid.fallback_handler import FallbackHandler


 class DummyLogger:
     def debug(self, msg):
         pass

     def error(self, msg):
         pass


 class DummyAgent:
     provider = "openai"
     tools = []
     model = None


 class TestFallbackHandler(unittest.TestCase):
     def setUp(self):
-        self.config = {"max_tool_failures": 2, "fallback_tool_models": "dummy-fallback-model"}
+        self.config = {
+            "max_tool_failures": 2,
+            "fallback_tool_models": "dummy-fallback-model",
+        }
         self.fallback_handler = FallbackHandler(self.config)
         self.logger = DummyLogger()
         self.agent = DummyAgent()

     def test_handle_failure_increments_counter(self):
         initial_failures = self.fallback_handler.tool_failure_consecutive_failures
-        self.fallback_handler.handle_failure("dummy_call()", Exception("Test error"), self.logger, self.agent)
-        self.assertEqual(self.fallback_handler.tool_failure_consecutive_failures, initial_failures + 1)
+        self.fallback_handler.handle_failure(
+            "dummy_call()", Exception("Test error"), self.logger, self.agent
+        )
+        self.assertEqual(
+            self.fallback_handler.tool_failure_consecutive_failures,
+            initial_failures + 1,
+        )

     def test_attempt_fallback_resets_counter(self):
         # Monkey-patch dummy functions for fallback components
@@ -30,6 +43,7 @@ class TestFallbackHandler(unittest.TestCase):
             class DummyModel:
                 def bind_tools(self, tools, tool_choice):
                     pass
+
             return DummyModel()

         def dummy_merge_chat_history():
@@ -39,6 +53,7 @@ class TestFallbackHandler(unittest.TestCase):
             return True

         import ra_aid.llm as llm
+
         original_initialize = llm.initialize_llm
         original_merge = llm.merge_chat_history
         original_validate = llm.validate_provider_env
@@ -47,12 +62,15 @@ class TestFallbackHandler(unittest.TestCase):
         llm.validate_provider_env = dummy_validate_provider_env

         self.fallback_handler.tool_failure_consecutive_failures = 2
-        self.fallback_handler.attempt_fallback("dummy_tool_call()", self.logger, self.agent)
+        self.fallback_handler.attempt_fallback(
+            "dummy_tool_call()", self.logger, self.agent
+        )
         self.assertEqual(self.fallback_handler.tool_failure_consecutive_failures, 0)

         llm.initialize_llm = original_initialize
         llm.merge_chat_history = original_merge
         llm.validate_provider_env = original_validate


 if __name__ == "__main__":
     unittest.main()