refactor(agent_utils.py): remove the _handle_tool_execution_error function and simplify error handling in run_agent_with_retry
feat(fallback_handler.py): enhance handle_failure method to extract tool name from ToolExecutionError and improve fallback logic fix(exceptions.py): update ToolExecutionError to include base_message for better error context feat(output.py): add base_message to ToolExecutionError for improved debugging chore(tool_configs.py): update get_all_tools function to specify return type style(logging_config.py): reorder imports for consistency test(tests): add tests for new error handling and fallback logic in agent_utils and fallback_handler
This commit is contained in:
parent
a7322eaef2
commit
af9f95ceb1
|
|
@ -9,7 +9,6 @@ import uuid
|
|||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Literal, Optional, Sequence
|
||||
|
||||
from langgraph.graph.graph import CompiledGraph
|
||||
import litellm
|
||||
from anthropic import APIError, APITimeoutError, InternalServerError, RateLimitError
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
|
|
@ -20,6 +19,7 @@ from langchain_core.messages import (
|
|||
)
|
||||
from langchain_core.tools import tool
|
||||
from langgraph.checkpoint.memory import MemorySaver
|
||||
from langgraph.graph.graph import CompiledGraph
|
||||
from langgraph.prebuilt import create_react_agent
|
||||
from langgraph.prebuilt.chat_agent_executor import AgentState
|
||||
from litellm import get_model_info
|
||||
|
|
@ -876,9 +876,7 @@ def run_agent_with_retry(
|
|||
logger.debug("Agent run completed successfully")
|
||||
return "Agent run completed successfully"
|
||||
except ToolExecutionError as e:
|
||||
fallback_response = _handle_tool_execution_error(
|
||||
fallback_handler, agent, e
|
||||
)
|
||||
fallback_response = fallback_handler.handle_failure(e, agent)
|
||||
if fallback_response:
|
||||
prompt = original_prompt + "\n" + fallback_response
|
||||
continue
|
||||
|
|
@ -895,42 +893,3 @@ def run_agent_with_retry(
|
|||
finally:
|
||||
_decrement_agent_depth()
|
||||
_restore_interrupt_handling(original_handler)
|
||||
|
||||
|
||||
def _handle_tool_execution_error(
|
||||
fallback_handler: FallbackHandler,
|
||||
agent: CiaynAgent | CompiledGraph,
|
||||
error: ToolExecutionError,
|
||||
):
|
||||
logger.debug("Entering _handle_tool_execution_error with error: %s", error)
|
||||
if error.tool_name:
|
||||
failed_tool_call_name = error.tool_name
|
||||
logger.debug(
|
||||
"Extracted failed_tool_call_name from error.tool_name: %s",
|
||||
failed_tool_call_name,
|
||||
)
|
||||
else:
|
||||
import re
|
||||
|
||||
msg = str(error)
|
||||
logger.debug("Error message: %s", msg)
|
||||
match = re.search(r"name=['\"](\w+)['\"]", msg)
|
||||
if match:
|
||||
failed_tool_call_name = match.group(1)
|
||||
logger.debug(
|
||||
"Extracted failed_tool_call_name using regex: %s", failed_tool_call_name
|
||||
)
|
||||
else:
|
||||
failed_tool_call_name = "Tool execution error"
|
||||
logger.debug(
|
||||
"Defaulting failed_tool_call_name to: %s", failed_tool_call_name
|
||||
)
|
||||
logger.debug(
|
||||
"Calling fallback_handler.handle_failure with failed_tool_call_name: %s",
|
||||
failed_tool_call_name,
|
||||
)
|
||||
fallback_response = fallback_handler.handle_failure(
|
||||
failed_tool_call_name, error, agent
|
||||
)
|
||||
logger.debug("Fallback response received: %s", fallback_response)
|
||||
return fallback_response
|
||||
|
|
|
|||
|
|
@ -44,7 +44,9 @@ def print_agent_output(chunk: Dict[str, Any]) -> None:
|
|||
)
|
||||
)
|
||||
tool_name = getattr(msg, "name", None)
|
||||
raise ToolExecutionError(err_msg, tool_name=tool_name)
|
||||
cpm(f"type(msg): {type(msg)}")
|
||||
cpm(f"msg: {msg}")
|
||||
raise ToolExecutionError(err_msg, tool_name=tool_name, base_message=msg)
|
||||
|
||||
|
||||
def cpm(message: str, title: Optional[str] = None, border_style: str = "blue") -> None:
|
||||
|
|
|
|||
|
|
@ -1,5 +1,9 @@
|
|||
"""Custom exceptions for RA.Aid."""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from langchain_core.messages import BaseMessage
|
||||
|
||||
|
||||
class AgentInterrupt(Exception):
|
||||
"""Exception raised when an agent's execution is interrupted.
|
||||
|
|
@ -17,6 +21,13 @@ class ToolExecutionError(Exception):
|
|||
This exception is used to distinguish tool execution failures
|
||||
from other types of errors in the agent system.
|
||||
"""
|
||||
def __init__(self, message, tool_name=None):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
base_message: Optional[BaseMessage] = None,
|
||||
tool_name: Optional[str] = None,
|
||||
):
|
||||
super().__init__(message)
|
||||
self.base_message = base_message
|
||||
self.tool_name = tool_name
|
||||
|
|
|
|||
|
|
@ -1,20 +1,22 @@
|
|||
import json
|
||||
import re
|
||||
|
||||
from langchain_core.tools import BaseTool
|
||||
from langgraph.graph.graph import CompiledGraph
|
||||
from langgraph.graph.message import BaseMessage
|
||||
|
||||
from ra_aid.console.output import cpm
|
||||
import json
|
||||
|
||||
from ra_aid.agents.ciayn_agent import CiaynAgent
|
||||
from ra_aid.config import (
|
||||
DEFAULT_MAX_TOOL_FAILURES,
|
||||
FALLBACK_TOOL_MODEL_LIMIT,
|
||||
RETRY_FALLBACK_COUNT,
|
||||
)
|
||||
from ra_aid.logging_config import get_logger
|
||||
from ra_aid.tool_leaderboard import supported_top_tool_models
|
||||
from rich.console import Console
|
||||
from ra_aid.console.output import cpm
|
||||
from ra_aid.exceptions import ToolExecutionError
|
||||
from ra_aid.llm import initialize_llm, validate_provider_env
|
||||
from ra_aid.logging_config import get_logger
|
||||
from ra_aid.tool_configs import get_all_tools
|
||||
from ra_aid.tool_leaderboard import supported_top_tool_models
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
|
@ -41,9 +43,12 @@ class FallbackHandler:
|
|||
self.tools: list[BaseTool] = tools
|
||||
self.fallback_enabled = config.get("fallback_tool_enabled", True)
|
||||
self.fallback_tool_models = self._load_fallback_tool_models(config)
|
||||
self.max_failures = config.get("max_tool_failures", DEFAULT_MAX_TOOL_FAILURES)
|
||||
self.tool_failure_consecutive_failures = 0
|
||||
self.failed_messages = set()
|
||||
self.tool_failure_used_fallbacks = set()
|
||||
self.console = Console()
|
||||
self.current_failing_tool_name = ""
|
||||
self.current_tool_to_bind = None
|
||||
|
||||
def _load_fallback_tool_models(self, config):
|
||||
"""
|
||||
|
|
@ -87,66 +92,104 @@ class FallbackHandler:
|
|||
cpm(message, title="Fallback Models")
|
||||
return final_models
|
||||
|
||||
def handle_failure(
|
||||
self, code: str, error: Exception, agent: CiaynAgent | CompiledGraph
|
||||
):
|
||||
def handle_failure(self, error: Exception, agent: CiaynAgent | CompiledGraph):
|
||||
"""
|
||||
Handle a tool failure by incrementing the failure counter and triggering fallback if thresholds are exceeded.
|
||||
|
||||
Args:
|
||||
code (str): The code that failed to execute.
|
||||
error (Exception): The exception raised during execution.
|
||||
logger: Logger instance for logging.
|
||||
error (Exception): The exception raised during execution. If the exception has a 'base_message' attribute, that message is recorded.
|
||||
agent: The agent instance on which fallback may be executed.
|
||||
"""
|
||||
logger.debug(
|
||||
f"_handle_tool_failure: tool failure encountered for code '{code}' with error: {error}"
|
||||
)
|
||||
self.tool_failure_consecutive_failures += 1
|
||||
max_failures = self.config.get("max_tool_failures", DEFAULT_MAX_TOOL_FAILURES)
|
||||
logger.debug(
|
||||
f"_handle_tool_failure: failure count {self.tool_failure_consecutive_failures}, max_failures {max_failures}"
|
||||
)
|
||||
if not self.fallback_enabled:
|
||||
return None
|
||||
|
||||
failed_tool_call_name = self.extract_failed_tool_name(error)
|
||||
if (
|
||||
self.fallback_enabled
|
||||
and self.tool_failure_consecutive_failures >= max_failures
|
||||
self.current_failing_tool_name
|
||||
and failed_tool_call_name != self.current_failing_tool_name
|
||||
):
|
||||
logger.debug(
|
||||
"New failing tool name identified. Resetting consecutive tool failures."
|
||||
)
|
||||
self.reset_fallback_handler()
|
||||
|
||||
logger.debug(
|
||||
f"_handle_tool_failure: tool failure encountered for code '{failed_tool_call_name}' with error: {error}"
|
||||
)
|
||||
|
||||
self.current_failing_tool_name = failed_tool_call_name
|
||||
self.current_tool_to_bind = self._find_tool_to_bind(
|
||||
agent, failed_tool_call_name
|
||||
)
|
||||
|
||||
if hasattr(error, "base_message") and error.base_message:
|
||||
self.failed_messages.add(str(error.base_message))
|
||||
|
||||
self.tool_failure_consecutive_failures += 1
|
||||
logger.debug(
|
||||
f"_handle_tool_failure: failure count {self.tool_failure_consecutive_failures}, max_failures {self.max_failures}"
|
||||
)
|
||||
|
||||
if (
|
||||
self.tool_failure_consecutive_failures >= self.max_failures
|
||||
and self.fallback_tool_models
|
||||
):
|
||||
logger.debug(
|
||||
"_handle_tool_failure: threshold reached, invoking fallback mechanism."
|
||||
)
|
||||
return self.attempt_fallback(code, logger, agent)
|
||||
return self.attempt_fallback()
|
||||
|
||||
def attempt_fallback(self, code: str, logger, agent):
|
||||
def attempt_fallback(self):
|
||||
"""
|
||||
Initiate the fallback process by selecting a fallback model and triggering the appropriate fallback method.
|
||||
Initiate the fallback process by iterating over all fallback models and triggering the appropriate fallback method.
|
||||
|
||||
Args:
|
||||
code (str): The tool code that triggered the fallback.
|
||||
logger: Logger instance for logging messages.
|
||||
agent: The agent for which fallback is being executed.
|
||||
Returns:
|
||||
The response from a fallback model if any, otherwise None.
|
||||
"""
|
||||
logger.debug(f"_attempt_fallback: initiating fallback for code: {code}")
|
||||
fallback_model = self.fallback_tool_models[0]
|
||||
failed_tool_call_name = code
|
||||
logger.error(
|
||||
f"Tool call failed {self.tool_failure_consecutive_failures} times. Attempting fallback to model: {fallback_model['model']} for tool: {failed_tool_call_name}"
|
||||
f"Tool call failed {self.tool_failure_consecutive_failures} times. Attempting fallback for tool: {self.current_failing_tool_name}"
|
||||
)
|
||||
cpm(
|
||||
f"**Tool fallback activated**: Switching to fallback model {fallback_model['model']} for tool {failed_tool_call_name}.",
|
||||
f"**Tool fallback activated**: Attempting fallback for tool {self.current_failing_tool_name}.",
|
||||
title="Fallback Notification",
|
||||
)
|
||||
if fallback_model.get("type", "prompt").lower() == "fc":
|
||||
self.attempt_fallback_function(code, logger, agent)
|
||||
else:
|
||||
self.attempt_fallback_prompt(code, logger, agent)
|
||||
for fallback_model in self.fallback_tool_models:
|
||||
if fallback_model.get("type", "prompt").lower() == "fc":
|
||||
response = self.attempt_fallback_function(fallback_model)
|
||||
else:
|
||||
response = self.attempt_fallback_prompt(fallback_model)
|
||||
if response:
|
||||
return response
|
||||
cpm("All fallback models have failed", title="Fallback Failed")
|
||||
return None
|
||||
|
||||
def reset_fallback_handler(self):
|
||||
"""
|
||||
Reset the fallback handler's internal failure counters and clear the record of used fallback models.
|
||||
"""
|
||||
self.tool_failure_consecutive_failures = 0
|
||||
self.failed_messages.clear()
|
||||
self.tool_failure_used_fallbacks.clear()
|
||||
self.fallback_tool_models = self._load_fallback_tool_models(self.config)
|
||||
|
||||
def extract_failed_tool_name(self, error: ToolExecutionError):
|
||||
if error.tool_name:
|
||||
failed_tool_call_name = error.tool_name
|
||||
else:
|
||||
msg = str(error)
|
||||
logger.debug("Error message: %s", msg)
|
||||
match = re.search(r"name=['\"](\w+)['\"]", msg)
|
||||
if match:
|
||||
failed_tool_call_name = str(match.group(1))
|
||||
logger.debug(
|
||||
"Extracted failed_tool_call_name using regex: %s",
|
||||
failed_tool_call_name,
|
||||
)
|
||||
else:
|
||||
failed_tool_call_name = "Tool execution error"
|
||||
raise Exception("Fallback failed: Could not extract failed tool name.")
|
||||
|
||||
return failed_tool_call_name
|
||||
|
||||
def _find_tool_to_bind(self, agent, failed_tool_call_name):
|
||||
logger.debug(f"failed_tool_call_name={failed_tool_call_name}")
|
||||
|
|
@ -157,135 +200,108 @@ class FallbackHandler:
|
|||
None,
|
||||
)
|
||||
if tool_to_bind is None:
|
||||
from ra_aid.tool_configs import get_all_tools
|
||||
|
||||
all_tools = get_all_tools()
|
||||
tool_to_bind = next(
|
||||
(t for t in all_tools if t.func.__name__ == failed_tool_call_name),
|
||||
None,
|
||||
)
|
||||
if tool_to_bind is None:
|
||||
available = [t.func.__name__ for t in get_all_tools()]
|
||||
logger.debug(
|
||||
f"Failed to find tool: {failed_tool_call_name}. Available tools: {available}"
|
||||
# TODO: Would be nice to try fuzzy match or levenstein str match to find closest correspond tool name
|
||||
raise Exception(
|
||||
f"Fallback failed: {failed_tool_call_name} not found in all tools."
|
||||
)
|
||||
raise Exception(f"Tool {failed_tool_call_name} not found in all tools.")
|
||||
return tool_to_bind
|
||||
|
||||
def attempt_fallback_prompt(self, code: str, logger, agent):
|
||||
def attempt_fallback_prompt(self, fallback_model):
|
||||
"""
|
||||
Attempt a prompt-based fallback by iterating over fallback models and invoking the provided code.
|
||||
|
||||
This method tries each fallback model (with retry logic configured) until one successfully executes the code.
|
||||
Attempt a prompt-based fallback by invoking the current failing tool with the given fallback model.
|
||||
|
||||
Args:
|
||||
code (str): The tool code to invoke via fallback.
|
||||
logger: Logger instance for logging messages.
|
||||
agent: The agent instance to update with the new model upon success.
|
||||
fallback_model (dict): The fallback model to use.
|
||||
|
||||
Returns:
|
||||
The response from the fallback model invocation.
|
||||
|
||||
Raises:
|
||||
Exception: If all prompt-based fallback models fail.
|
||||
The response from the fallback model invocation, or None if failed.
|
||||
"""
|
||||
logger.debug("Attempting prompt-based fallback using fallback models")
|
||||
failed_tool_call_name = code
|
||||
for fallback_model in self.fallback_tool_models:
|
||||
try:
|
||||
logger.debug(f"Trying fallback model: {fallback_model['model']}")
|
||||
simple_model = initialize_llm(
|
||||
fallback_model["provider"], fallback_model["model"]
|
||||
)
|
||||
tool_to_bind = self._find_tool_to_bind(agent, failed_tool_call_name)
|
||||
binded_model = simple_model.bind_tools(
|
||||
[tool_to_bind], tool_choice=failed_tool_call_name
|
||||
)
|
||||
# retry_model = binded_model.with_retry(
|
||||
# stop_after_attempt=RETRY_FALLBACK_COUNT
|
||||
# )
|
||||
response = binded_model.invoke(code)
|
||||
cpm(f"response={response}")
|
||||
|
||||
self.tool_failure_used_fallbacks.add(fallback_model["model"])
|
||||
|
||||
tool_call = self.base_message_to_tool_call_dict(response)
|
||||
if tool_call:
|
||||
result = self.invoke_prompt_tool_call(tool_call)
|
||||
cpm(f"result={result}")
|
||||
logger.debug(
|
||||
"Prompt-based fallback executed successfully with model: "
|
||||
+ fallback_model["model"]
|
||||
)
|
||||
self.reset_fallback_handler()
|
||||
return result
|
||||
else:
|
||||
cpm(
|
||||
response.content if hasattr(response, "content") else response,
|
||||
title="Fallback Model Response: " + fallback_model["model"],
|
||||
)
|
||||
return response
|
||||
except Exception as e:
|
||||
if isinstance(e, KeyboardInterrupt):
|
||||
raise
|
||||
logger.error(
|
||||
f"Prompt-based fallback with model {fallback_model['model']} failed: {e}"
|
||||
)
|
||||
raise Exception("All prompt-based fallback models failed")
|
||||
|
||||
def attempt_fallback_function(self, code: str, logger, agent):
|
||||
"""
|
||||
Attempt a function-calling fallback by iterating over fallback models and invoking the provided code.
|
||||
|
||||
This method tries each fallback model (with retry logic configured) until one successfully executes the code.
|
||||
|
||||
Args:
|
||||
code (str): The tool code to invoke via fallback.
|
||||
logger: Logger instance for logging messages.
|
||||
agent: The agent instance to update with the new model upon success.
|
||||
|
||||
Returns:
|
||||
The response from the fallback model invocation.
|
||||
|
||||
Raises:
|
||||
Exception: If all function-calling fallback models fail.
|
||||
"""
|
||||
logger.debug("Attempting function-calling fallback using fallback models")
|
||||
failed_tool_call_name = code
|
||||
for fallback_model in self.fallback_tool_models:
|
||||
try:
|
||||
logger.debug(f"Trying fallback model: {fallback_model['model']}")
|
||||
simple_model = initialize_llm(
|
||||
fallback_model["provider"], fallback_model["model"]
|
||||
)
|
||||
tool_to_bind = self._find_tool_to_bind(agent, failed_tool_call_name)
|
||||
binded_model = simple_model.bind_tools(
|
||||
[tool_to_bind], tool_choice=failed_tool_call_name
|
||||
)
|
||||
retry_model = binded_model.with_retry(
|
||||
stop_after_attempt=RETRY_FALLBACK_COUNT
|
||||
)
|
||||
response = retry_model.invoke(code)
|
||||
cpm(f"response={response}")
|
||||
self.tool_failure_used_fallbacks.add(fallback_model["model"])
|
||||
self.reset_fallback_handler()
|
||||
try:
|
||||
logger.debug(f"Trying fallback model: {fallback_model['model']}")
|
||||
simple_model = initialize_llm(
|
||||
fallback_model["provider"], fallback_model["model"]
|
||||
)
|
||||
binded_model = simple_model.bind_tools(
|
||||
[self.current_tool_to_bind],
|
||||
tool_choice=self.current_failing_tool_name,
|
||||
)
|
||||
retry_model = binded_model.with_retry(
|
||||
stop_after_attempt=RETRY_FALLBACK_COUNT
|
||||
)
|
||||
response = retry_model.invoke(self.current_failing_tool_name)
|
||||
cpm(f"response={response}")
|
||||
self.tool_failure_used_fallbacks.add(fallback_model["model"])
|
||||
tool_call = self.base_message_to_tool_call_dict(response)
|
||||
if tool_call:
|
||||
result = self.invoke_prompt_tool_call(tool_call)
|
||||
cpm(f"result={result}")
|
||||
logger.debug(
|
||||
"Function-calling fallback executed successfully with model: "
|
||||
"Prompt-based fallback executed successfully with model: "
|
||||
+ fallback_model["model"]
|
||||
)
|
||||
|
||||
self.reset_fallback_handler()
|
||||
return result
|
||||
else:
|
||||
cpm(
|
||||
response.content if hasattr(response, "content") else response,
|
||||
title="Fallback Model Response: " + fallback_model["model"],
|
||||
)
|
||||
return response
|
||||
except Exception as e:
|
||||
if isinstance(e, KeyboardInterrupt):
|
||||
raise
|
||||
logger.error(
|
||||
f"Function-calling fallback with model {fallback_model['model']} failed: {e}"
|
||||
)
|
||||
raise Exception("All function-calling fallback models failed")
|
||||
except Exception as e:
|
||||
if isinstance(e, KeyboardInterrupt):
|
||||
raise
|
||||
logger.error(
|
||||
f"Prompt-based fallback with model {fallback_model['model']} failed: {e}"
|
||||
)
|
||||
return None
|
||||
|
||||
def attempt_fallback_function(self, fallback_model):
|
||||
"""
|
||||
Attempt a function-calling fallback by invoking the current failing tool with the given fallback model.
|
||||
|
||||
Args:
|
||||
fallback_model (dict): The fallback model to use.
|
||||
|
||||
Returns:
|
||||
The response from the fallback model invocation, or None if failed.
|
||||
"""
|
||||
try:
|
||||
logger.debug(f"Trying fallback model: {fallback_model['model']}")
|
||||
simple_model = initialize_llm(
|
||||
fallback_model["provider"], fallback_model["model"]
|
||||
)
|
||||
binded_model = simple_model.bind_tools(
|
||||
[self.current_tool_to_bind],
|
||||
tool_choice=self.current_failing_tool_name,
|
||||
)
|
||||
retry_model = binded_model.with_retry(
|
||||
stop_after_attempt=RETRY_FALLBACK_COUNT
|
||||
)
|
||||
response = retry_model.invoke(self.current_failing_tool_name)
|
||||
cpm(f"response={response}")
|
||||
self.tool_failure_used_fallbacks.add(fallback_model["model"])
|
||||
logger.debug(
|
||||
"Function-calling fallback executed successfully with model: "
|
||||
+ fallback_model["model"]
|
||||
)
|
||||
cpm(
|
||||
response.content if hasattr(response, "content") else response,
|
||||
title="Fallback Model Response: " + fallback_model["model"],
|
||||
)
|
||||
return response
|
||||
except Exception as e:
|
||||
if isinstance(e, KeyboardInterrupt):
|
||||
raise
|
||||
logger.error(
|
||||
f"Function-calling fallback with model {fallback_model['model']} failed: {e}"
|
||||
)
|
||||
return None
|
||||
|
||||
def invoke_prompt_tool_call(self, tool_call_request: dict):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -1,9 +1,10 @@
|
|||
import logging
|
||||
import sys
|
||||
from typing import Optional
|
||||
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
from rich.markdown import Markdown
|
||||
from rich.panel import Panel
|
||||
|
||||
|
||||
class PrettyHandler(logging.Handler):
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
from langchain_core.tools import BaseTool
|
||||
|
||||
from ra_aid.tools import (
|
||||
ask_expert,
|
||||
ask_human,
|
||||
|
|
@ -61,11 +63,13 @@ def get_read_only_tools(
|
|||
|
||||
return tools
|
||||
|
||||
|
||||
def get_all_tools_simple():
|
||||
"""Return a list containing all available tools using existing group methods."""
|
||||
return get_all_tools()
|
||||
|
||||
def get_all_tools():
|
||||
|
||||
def get_all_tools() -> list[BaseTool]:
|
||||
"""Return a list containing all available tools from different groups."""
|
||||
all_tools = []
|
||||
all_tools.extend(get_read_only_tools())
|
||||
|
|
@ -176,7 +180,7 @@ def get_implementation_tools(
|
|||
return tools
|
||||
|
||||
|
||||
def get_web_research_tools(expert_enabled: bool = True) -> list:
|
||||
def get_web_research_tools(expert_enabled: bool = True):
|
||||
"""Get the list of tools available for web research.
|
||||
|
||||
Args:
|
||||
|
|
@ -196,9 +200,7 @@ def get_web_research_tools(expert_enabled: bool = True) -> list:
|
|||
return tools
|
||||
|
||||
|
||||
def get_chat_tools(
|
||||
expert_enabled: bool = True, web_research_enabled: bool = False
|
||||
) -> list:
|
||||
def get_chat_tools(expert_enabled: bool = True, web_research_enabled: bool = False):
|
||||
"""Get the list of tools available in chat mode.
|
||||
|
||||
Chat mode includes research and implementation capabilities but excludes
|
||||
|
|
|
|||
|
|
@ -276,11 +276,19 @@ def test_get_model_token_limit_planner(mock_memory):
|
|||
token_limit = get_model_token_limit(config, "planner")
|
||||
assert token_limit == 120000
|
||||
|
||||
|
||||
# New tests for private helper methods in agent_utils.py
|
||||
|
||||
|
||||
def test_setup_and_restore_interrupt_handling():
|
||||
import signal, threading
|
||||
from ra_aid.agent_utils import _setup_interrupt_handling, _restore_interrupt_handling, _request_interrupt
|
||||
import signal
|
||||
|
||||
from ra_aid.agent_utils import (
|
||||
_request_interrupt,
|
||||
_restore_interrupt_handling,
|
||||
_setup_interrupt_handling,
|
||||
)
|
||||
|
||||
original_handler = signal.getsignal(signal.SIGINT)
|
||||
handler = _setup_interrupt_handling()
|
||||
# Verify the SIGINT handler is set to _request_interrupt
|
||||
|
|
@ -289,66 +297,93 @@ def test_setup_and_restore_interrupt_handling():
|
|||
# Verify the SIGINT handler is restored to the original
|
||||
assert signal.getsignal(signal.SIGINT) == original_handler
|
||||
|
||||
|
||||
def test_increment_and_decrement_agent_depth():
|
||||
from ra_aid.agent_utils import _increment_agent_depth, _decrement_agent_depth, _global_memory
|
||||
from ra_aid.agent_utils import (
|
||||
_decrement_agent_depth,
|
||||
_global_memory,
|
||||
_increment_agent_depth,
|
||||
)
|
||||
|
||||
_global_memory["agent_depth"] = 10
|
||||
_increment_agent_depth()
|
||||
assert _global_memory["agent_depth"] == 11
|
||||
_decrement_agent_depth()
|
||||
assert _global_memory["agent_depth"] == 10
|
||||
|
||||
|
||||
def test_run_agent_stream(monkeypatch):
|
||||
from ra_aid.agent_utils import _run_agent_stream, _global_memory
|
||||
from ra_aid.agent_utils import _global_memory, _run_agent_stream
|
||||
|
||||
# Create a dummy agent that yields one chunk
|
||||
class DummyAgent:
|
||||
def stream(self, msg, cfg):
|
||||
yield {"content": "chunk1"}
|
||||
|
||||
dummy_agent = DummyAgent()
|
||||
# Set flags so that _run_agent_stream will reset them
|
||||
_global_memory["plan_completed"] = True
|
||||
_global_memory["task_completed"] = True
|
||||
_global_memory["completion_message"] = "existing"
|
||||
call_flag = {"called": False}
|
||||
|
||||
def fake_print_agent_output(chunk):
|
||||
call_flag["called"] = True
|
||||
monkeypatch.setattr("ra_aid.agent_utils.print_agent_output", fake_print_agent_output)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"ra_aid.agent_utils.print_agent_output", fake_print_agent_output
|
||||
)
|
||||
_run_agent_stream(dummy_agent, "dummy prompt", {})
|
||||
assert call_flag["called"]
|
||||
assert _global_memory["plan_completed"] is False
|
||||
assert _global_memory["task_completed"] is False
|
||||
assert _global_memory["completion_message"] == ""
|
||||
|
||||
|
||||
def test_execute_test_command_wrapper(monkeypatch):
|
||||
from ra_aid.agent_utils import _execute_test_command_wrapper
|
||||
|
||||
# Patch execute_test_command to return a testable tuple
|
||||
def fake_execute(config, orig, tests, auto):
|
||||
return (True, "new prompt", auto, tests + 1)
|
||||
|
||||
monkeypatch.setattr("ra_aid.agent_utils.execute_test_command", fake_execute)
|
||||
result = _execute_test_command_wrapper("orig", {}, 0, False)
|
||||
assert result == (True, "new prompt", False, 1)
|
||||
|
||||
|
||||
def test_handle_api_error_valueerror():
|
||||
from ra_aid.agent_utils import _handle_api_error
|
||||
import pytest
|
||||
|
||||
from ra_aid.agent_utils import _handle_api_error
|
||||
|
||||
# ValueError not containing "code" or "429" should be re-raised
|
||||
with pytest.raises(ValueError):
|
||||
_handle_api_error(ValueError("some error"), 0, 5, 1)
|
||||
|
||||
|
||||
def test_handle_api_error_max_retries():
|
||||
from ra_aid.agent_utils import _handle_api_error
|
||||
import pytest
|
||||
|
||||
from ra_aid.agent_utils import _handle_api_error
|
||||
|
||||
# When attempt reaches max retries, a RuntimeError should be raised
|
||||
with pytest.raises(RuntimeError):
|
||||
_handle_api_error(Exception("error code 429"), 4, 5, 1)
|
||||
|
||||
|
||||
def test_handle_api_error_retry(monkeypatch):
|
||||
from ra_aid.agent_utils import _handle_api_error
|
||||
import time
|
||||
|
||||
from ra_aid.agent_utils import _handle_api_error
|
||||
|
||||
# Patch time.monotonic and time.sleep to simulate immediate delay expiration
|
||||
fake_time = [0]
|
||||
|
||||
def fake_monotonic():
|
||||
fake_time[0] += 0.5
|
||||
return fake_time[0]
|
||||
|
||||
monkeypatch.setattr(time, "monotonic", fake_monotonic)
|
||||
monkeypatch.setattr(time, "sleep", lambda s: None)
|
||||
# Should not raise error when attempt is lower than max retries
|
||||
|
|
|
|||
|
|
@ -1,28 +1,41 @@
|
|||
import unittest
|
||||
|
||||
from ra_aid.fallback_handler import FallbackHandler
|
||||
|
||||
|
||||
class DummyLogger:
|
||||
def debug(self, msg):
|
||||
pass
|
||||
|
||||
def error(self, msg):
|
||||
pass
|
||||
|
||||
|
||||
class DummyAgent:
|
||||
provider = "openai"
|
||||
tools = []
|
||||
model = None
|
||||
|
||||
|
||||
class TestFallbackHandler(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.config = {"max_tool_failures": 2, "fallback_tool_models": "dummy-fallback-model"}
|
||||
self.config = {
|
||||
"max_tool_failures": 2,
|
||||
"fallback_tool_models": "dummy-fallback-model",
|
||||
}
|
||||
self.fallback_handler = FallbackHandler(self.config)
|
||||
self.logger = DummyLogger()
|
||||
self.agent = DummyAgent()
|
||||
|
||||
def test_handle_failure_increments_counter(self):
|
||||
initial_failures = self.fallback_handler.tool_failure_consecutive_failures
|
||||
self.fallback_handler.handle_failure("dummy_call()", Exception("Test error"), self.logger, self.agent)
|
||||
self.assertEqual(self.fallback_handler.tool_failure_consecutive_failures, initial_failures + 1)
|
||||
self.fallback_handler.handle_failure(
|
||||
"dummy_call()", Exception("Test error"), self.logger, self.agent
|
||||
)
|
||||
self.assertEqual(
|
||||
self.fallback_handler.tool_failure_consecutive_failures,
|
||||
initial_failures + 1,
|
||||
)
|
||||
|
||||
def test_attempt_fallback_resets_counter(self):
|
||||
# Monkey-patch dummy functions for fallback components
|
||||
|
|
@ -30,6 +43,7 @@ class TestFallbackHandler(unittest.TestCase):
|
|||
class DummyModel:
|
||||
def bind_tools(self, tools, tool_choice):
|
||||
pass
|
||||
|
||||
return DummyModel()
|
||||
|
||||
def dummy_merge_chat_history():
|
||||
|
|
@ -39,6 +53,7 @@ class TestFallbackHandler(unittest.TestCase):
|
|||
return True
|
||||
|
||||
import ra_aid.llm as llm
|
||||
|
||||
original_initialize = llm.initialize_llm
|
||||
original_merge = llm.merge_chat_history
|
||||
original_validate = llm.validate_provider_env
|
||||
|
|
@ -47,12 +62,15 @@ class TestFallbackHandler(unittest.TestCase):
|
|||
llm.validate_provider_env = dummy_validate_provider_env
|
||||
|
||||
self.fallback_handler.tool_failure_consecutive_failures = 2
|
||||
self.fallback_handler.attempt_fallback("dummy_tool_call()", self.logger, self.agent)
|
||||
self.fallback_handler.attempt_fallback(
|
||||
"dummy_tool_call()", self.logger, self.agent
|
||||
)
|
||||
self.assertEqual(self.fallback_handler.tool_failure_consecutive_failures, 0)
|
||||
|
||||
llm.initialize_llm = original_initialize
|
||||
llm.merge_chat_history = original_merge
|
||||
llm.validate_provider_env = original_validate
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
|
|||
Loading…
Reference in New Issue