refactor(agent_utils.py): remove the _handle_tool_execution_error function and simplify error handling in run_agent_with_retry
feat(fallback_handler.py): enhance handle_failure method to extract tool name from ToolExecutionError and improve fallback logic
fix(exceptions.py): update ToolExecutionError to include base_message for better error context
feat(output.py): add base_message to ToolExecutionError for improved debugging
chore(tool_configs.py): update get_all_tools function to specify return type
style(logging_config.py): reorder imports for consistency
test(tests): add tests for new error handling and fallback logic in agent_utils and fallback_handler
parent a7322eaef2
commit af9f95ceb1
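At a glance, the retry loop now hands tool failures straight to the fallback handler instead of going through a private helper. A minimal, hypothetical sketch of that flow (not the full implementation; streaming and config plumbing are simplified):

```python
# Hypothetical, simplified sketch of run_agent_with_retry's new error path;
# the real function also handles API errors, interrupts, and test commands.
from ra_aid.exceptions import ToolExecutionError


def run_with_fallback(agent, prompt: str, fallback_handler, run_config: dict) -> str:
    original_prompt = prompt
    while True:
        try:
            # Stream the agent until it finishes or a tool call raises.
            for chunk in agent.stream({"messages": [("user", prompt)]}, run_config):
                pass  # the real code pretty-prints each chunk via print_agent_output
            return "Agent run completed successfully"
        except ToolExecutionError as e:
            # handle_failure() pulls the failing tool name from e.tool_name (or a
            # regex over str(e)), counts consecutive failures, and may return the
            # output of a fallback model once the failure threshold is reached.
            fallback_response = fallback_handler.handle_failure(e, agent)
            if fallback_response:
                prompt = original_prompt + "\n" + fallback_response
            # loop and retry with the (possibly augmented) prompt
```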
@@ -9,7 +9,6 @@ import uuid
 from datetime import datetime
 from typing import Any, Dict, List, Literal, Optional, Sequence

-from langgraph.graph.graph import CompiledGraph
 import litellm
 from anthropic import APIError, APITimeoutError, InternalServerError, RateLimitError
 from langchain_core.language_models import BaseChatModel
@@ -20,6 +19,7 @@ from langchain_core.messages import (
 )
 from langchain_core.tools import tool
 from langgraph.checkpoint.memory import MemorySaver
+from langgraph.graph.graph import CompiledGraph
 from langgraph.prebuilt import create_react_agent
 from langgraph.prebuilt.chat_agent_executor import AgentState
 from litellm import get_model_info
@@ -876,9 +876,7 @@ def run_agent_with_retry(
             logger.debug("Agent run completed successfully")
             return "Agent run completed successfully"
         except ToolExecutionError as e:
-            fallback_response = _handle_tool_execution_error(
-                fallback_handler, agent, e
-            )
+            fallback_response = fallback_handler.handle_failure(e, agent)
             if fallback_response:
                 prompt = original_prompt + "\n" + fallback_response
                 continue
@@ -895,42 +893,3 @@ def run_agent_with_retry(
         finally:
             _decrement_agent_depth()
             _restore_interrupt_handling(original_handler)
-
-
-def _handle_tool_execution_error(
-    fallback_handler: FallbackHandler,
-    agent: CiaynAgent | CompiledGraph,
-    error: ToolExecutionError,
-):
-    logger.debug("Entering _handle_tool_execution_error with error: %s", error)
-    if error.tool_name:
-        failed_tool_call_name = error.tool_name
-        logger.debug(
-            "Extracted failed_tool_call_name from error.tool_name: %s",
-            failed_tool_call_name,
-        )
-    else:
-        import re
-
-        msg = str(error)
-        logger.debug("Error message: %s", msg)
-        match = re.search(r"name=['\"](\w+)['\"]", msg)
-        if match:
-            failed_tool_call_name = match.group(1)
-            logger.debug(
-                "Extracted failed_tool_call_name using regex: %s", failed_tool_call_name
-            )
-        else:
-            failed_tool_call_name = "Tool execution error"
-            logger.debug(
-                "Defaulting failed_tool_call_name to: %s", failed_tool_call_name
-            )
-    logger.debug(
-        "Calling fallback_handler.handle_failure with failed_tool_call_name: %s",
-        failed_tool_call_name,
-    )
-    fallback_response = fallback_handler.handle_failure(
-        failed_tool_call_name, error, agent
-    )
-    logger.debug("Fallback response received: %s", fallback_response)
-    return fallback_response
@@ -44,7 +44,9 @@ def print_agent_output(chunk: Dict[str, Any]) -> None:
                 )
             )
             tool_name = getattr(msg, "name", None)
-            raise ToolExecutionError(err_msg, tool_name=tool_name)
+            cpm(f"type(msg): {type(msg)}")
+            cpm(f"msg: {msg}")
+            raise ToolExecutionError(err_msg, tool_name=tool_name, base_message=msg)


 def cpm(message: str, title: Optional[str] = None, border_style: str = "blue") -> None:
@@ -1,5 +1,9 @@
 """Custom exceptions for RA.Aid."""

+from typing import Optional
+
+from langchain_core.messages import BaseMessage
+

 class AgentInterrupt(Exception):
     """Exception raised when an agent's execution is interrupted.
@@ -17,6 +21,13 @@ class ToolExecutionError(Exception):
     This exception is used to distinguish tool execution failures
     from other types of errors in the agent system.
     """
-    def __init__(self, message, tool_name=None):
+
+    def __init__(
+        self,
+        message: str,
+        base_message: Optional[BaseMessage] = None,
+        tool_name: Optional[str] = None,
+    ):
         super().__init__(message)
+        self.base_message = base_message
         self.tool_name = tool_name
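The enriched exception now carries the offending message object alongside the tool name. A small, hypothetical usage sketch (the `AIMessage` and tool name below are stand-ins, not taken from this diff):

```python
# Hypothetical usage of the new ToolExecutionError fields.
from langchain_core.messages import AIMessage

from ra_aid.exceptions import ToolExecutionError

msg = AIMessage(content="tool call failed")  # stand-in for the offending BaseMessage

try:
    raise ToolExecutionError(
        "Error executing tool",
        base_message=msg,
        tool_name="run_programming_task",  # illustrative tool name
    )
except ToolExecutionError as e:
    print(e.tool_name)     # -> "run_programming_task"
    print(e.base_message)  # the original BaseMessage, available for dedup/debugging
```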
@@ -1,20 +1,22 @@
+import json
+import re
+
 from langchain_core.tools import BaseTool
 from langgraph.graph.graph import CompiledGraph
 from langgraph.graph.message import BaseMessage

-from ra_aid.console.output import cpm
-import json
-
 from ra_aid.agents.ciayn_agent import CiaynAgent
 from ra_aid.config import (
     DEFAULT_MAX_TOOL_FAILURES,
     FALLBACK_TOOL_MODEL_LIMIT,
     RETRY_FALLBACK_COUNT,
 )
-from ra_aid.logging_config import get_logger
-from ra_aid.tool_leaderboard import supported_top_tool_models
-from rich.console import Console
+from ra_aid.console.output import cpm
+from ra_aid.exceptions import ToolExecutionError
 from ra_aid.llm import initialize_llm, validate_provider_env
+from ra_aid.logging_config import get_logger
+from ra_aid.tool_configs import get_all_tools
+from ra_aid.tool_leaderboard import supported_top_tool_models

 logger = get_logger(__name__)
@@ -41,9 +43,12 @@ class FallbackHandler:
         self.tools: list[BaseTool] = tools
         self.fallback_enabled = config.get("fallback_tool_enabled", True)
         self.fallback_tool_models = self._load_fallback_tool_models(config)
+        self.max_failures = config.get("max_tool_failures", DEFAULT_MAX_TOOL_FAILURES)
         self.tool_failure_consecutive_failures = 0
+        self.failed_messages = set()
         self.tool_failure_used_fallbacks = set()
-        self.console = Console()
+        self.current_failing_tool_name = ""
+        self.current_tool_to_bind = None

     def _load_fallback_tool_models(self, config):
         """
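For reference, the handler is driven by plain config keys; the construction below mirrors the tests further down (values are illustrative, and the real constructor also accepts the agent's tool list):

```python
# Illustrative construction; key names match those read in __init__ and in the tests.
from ra_aid.fallback_handler import FallbackHandler

config = {
    "fallback_tool_enabled": True,   # master switch checked in handle_failure
    "max_tool_failures": 2,          # consecutive failures before fallback kicks in
    "fallback_tool_models": "dummy-fallback-model",  # optional override of the leaderboard models
}
handler = FallbackHandler(config)
```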
@@ -87,66 +92,104 @@ class FallbackHandler:
             cpm(message, title="Fallback Models")
         return final_models

-    def handle_failure(
-        self, code: str, error: Exception, agent: CiaynAgent | CompiledGraph
-    ):
+    def handle_failure(self, error: Exception, agent: CiaynAgent | CompiledGraph):
         """
         Handle a tool failure by incrementing the failure counter and triggering fallback if thresholds are exceeded.

         Args:
-            code (str): The code that failed to execute.
-            error (Exception): The exception raised during execution.
-            logger: Logger instance for logging.
+            error (Exception): The exception raised during execution. If the exception has a 'base_message' attribute, that message is recorded.
             agent: The agent instance on which fallback may be executed.
         """
-        logger.debug(
-            f"_handle_tool_failure: tool failure encountered for code '{code}' with error: {error}"
-        )
-        self.tool_failure_consecutive_failures += 1
-        max_failures = self.config.get("max_tool_failures", DEFAULT_MAX_TOOL_FAILURES)
-        logger.debug(
-            f"_handle_tool_failure: failure count {self.tool_failure_consecutive_failures}, max_failures {max_failures}"
-        )
+        if not self.fallback_enabled:
+            return None
+        failed_tool_call_name = self.extract_failed_tool_name(error)
         if (
-            self.fallback_enabled
-            and self.tool_failure_consecutive_failures >= max_failures
+            self.current_failing_tool_name
+            and failed_tool_call_name != self.current_failing_tool_name
+        ):
+            logger.debug(
+                "New failing tool name identified. Resetting consecutive tool failures."
+            )
+            self.reset_fallback_handler()
+
+        logger.debug(
+            f"_handle_tool_failure: tool failure encountered for code '{failed_tool_call_name}' with error: {error}"
+        )
+
+        self.current_failing_tool_name = failed_tool_call_name
+        self.current_tool_to_bind = self._find_tool_to_bind(
+            agent, failed_tool_call_name
+        )
+
+        if hasattr(error, "base_message") and error.base_message:
+            self.failed_messages.add(str(error.base_message))
+
+        self.tool_failure_consecutive_failures += 1
+        logger.debug(
+            f"_handle_tool_failure: failure count {self.tool_failure_consecutive_failures}, max_failures {self.max_failures}"
+        )
+
+        if (
+            self.tool_failure_consecutive_failures >= self.max_failures
             and self.fallback_tool_models
         ):
             logger.debug(
                 "_handle_tool_failure: threshold reached, invoking fallback mechanism."
             )
-            return self.attempt_fallback(code, logger, agent)
+            return self.attempt_fallback()

-    def attempt_fallback(self, code: str, logger, agent):
+    def attempt_fallback(self):
         """
-        Initiate the fallback process by selecting a fallback model and triggering the appropriate fallback method.
+        Initiate the fallback process by iterating over all fallback models and triggering the appropriate fallback method.

-        Args:
-            code (str): The tool code that triggered the fallback.
-            logger: Logger instance for logging messages.
-            agent: The agent for which fallback is being executed.
+        Returns:
+            The response from a fallback model if any, otherwise None.
         """
-        logger.debug(f"_attempt_fallback: initiating fallback for code: {code}")
-        fallback_model = self.fallback_tool_models[0]
-        failed_tool_call_name = code
         logger.error(
-            f"Tool call failed {self.tool_failure_consecutive_failures} times. Attempting fallback to model: {fallback_model['model']} for tool: {failed_tool_call_name}"
+            f"Tool call failed {self.tool_failure_consecutive_failures} times. Attempting fallback for tool: {self.current_failing_tool_name}"
         )
         cpm(
-            f"**Tool fallback activated**: Switching to fallback model {fallback_model['model']} for tool {failed_tool_call_name}.",
+            f"**Tool fallback activated**: Attempting fallback for tool {self.current_failing_tool_name}.",
             title="Fallback Notification",
         )
-        if fallback_model.get("type", "prompt").lower() == "fc":
-            self.attempt_fallback_function(code, logger, agent)
-        else:
-            self.attempt_fallback_prompt(code, logger, agent)
+        for fallback_model in self.fallback_tool_models:
+            if fallback_model.get("type", "prompt").lower() == "fc":
+                response = self.attempt_fallback_function(fallback_model)
+            else:
+                response = self.attempt_fallback_prompt(fallback_model)
+            if response:
+                return response
+        cpm("All fallback models have failed", title="Fallback Failed")
+        return None

     def reset_fallback_handler(self):
         """
         Reset the fallback handler's internal failure counters and clear the record of used fallback models.
         """
         self.tool_failure_consecutive_failures = 0
+        self.failed_messages.clear()
         self.tool_failure_used_fallbacks.clear()
+        self.fallback_tool_models = self._load_fallback_tool_models(self.config)
+
+    def extract_failed_tool_name(self, error: ToolExecutionError):
+        if error.tool_name:
+            failed_tool_call_name = error.tool_name
+        else:
+            msg = str(error)
+            logger.debug("Error message: %s", msg)
+            match = re.search(r"name=['\"](\w+)['\"]", msg)
+            if match:
+                failed_tool_call_name = str(match.group(1))
+                logger.debug(
+                    "Extracted failed_tool_call_name using regex: %s",
+                    failed_tool_call_name,
+                )
+            else:
+                failed_tool_call_name = "Tool execution error"
+                raise Exception("Fallback failed: Could not extract failed tool name.")
+
+        return failed_tool_call_name

     def _find_tool_to_bind(self, agent, failed_tool_call_name):
         logger.debug(f"failed_tool_call_name={failed_tool_call_name}")
@@ -157,135 +200,108 @@ class FallbackHandler:
             None,
         )
         if tool_to_bind is None:
-            from ra_aid.tool_configs import get_all_tools
-
             all_tools = get_all_tools()
             tool_to_bind = next(
                 (t for t in all_tools if t.func.__name__ == failed_tool_call_name),
                 None,
             )
             if tool_to_bind is None:
-                available = [t.func.__name__ for t in get_all_tools()]
-                logger.debug(
-                    f"Failed to find tool: {failed_tool_call_name}. Available tools: {available}"
-                )
-                raise Exception(f"Tool {failed_tool_call_name} not found in all tools.")
+                # TODO: Would be nice to try fuzzy match or levenstein str match to find closest correspond tool name
+                raise Exception(
+                    f"Fallback failed: {failed_tool_call_name} not found in all tools."
+                )
         return tool_to_bind

-    def attempt_fallback_prompt(self, code: str, logger, agent):
-        """
-        Attempt a prompt-based fallback by iterating over fallback models and invoking the provided code.
-
-        This method tries each fallback model (with retry logic configured) until one successfully executes the code.
-
-        Args:
-            code (str): The tool code to invoke via fallback.
-            logger: Logger instance for logging messages.
-            agent: The agent instance to update with the new model upon success.
-
-        Returns:
-            The response from the fallback model invocation.
-
-        Raises:
-            Exception: If all prompt-based fallback models fail.
-        """
-        logger.debug("Attempting prompt-based fallback using fallback models")
-        failed_tool_call_name = code
-        for fallback_model in self.fallback_tool_models:
-            try:
-                logger.debug(f"Trying fallback model: {fallback_model['model']}")
-                simple_model = initialize_llm(
-                    fallback_model["provider"], fallback_model["model"]
-                )
-                tool_to_bind = self._find_tool_to_bind(agent, failed_tool_call_name)
-                binded_model = simple_model.bind_tools(
-                    [tool_to_bind], tool_choice=failed_tool_call_name
-                )
-                # retry_model = binded_model.with_retry(
-                #     stop_after_attempt=RETRY_FALLBACK_COUNT
-                # )
-                response = binded_model.invoke(code)
-                cpm(f"response={response}")
-                self.tool_failure_used_fallbacks.add(fallback_model["model"])
-
-                tool_call = self.base_message_to_tool_call_dict(response)
-                if tool_call:
-                    result = self.invoke_prompt_tool_call(tool_call)
-                    cpm(f"result={result}")
-                    logger.debug(
-                        "Prompt-based fallback executed successfully with model: "
-                        + fallback_model["model"]
-                    )
-                    self.reset_fallback_handler()
-                    return result
-                else:
-                    cpm(
-                        response.content if hasattr(response, "content") else response,
-                        title="Fallback Model Response: " + fallback_model["model"],
-                    )
-                    return response
-            except Exception as e:
-                if isinstance(e, KeyboardInterrupt):
-                    raise
-                logger.error(
-                    f"Prompt-based fallback with model {fallback_model['model']} failed: {e}"
-                )
-        raise Exception("All prompt-based fallback models failed")
+    def attempt_fallback_prompt(self, fallback_model):
+        """
+        Attempt a prompt-based fallback by invoking the current failing tool with the given fallback model.
+
+        Args:
+            fallback_model (dict): The fallback model to use.
+
+        Returns:
+            The response from the fallback model invocation, or None if failed.
+        """
+        try:
+            logger.debug(f"Trying fallback model: {fallback_model['model']}")
+            simple_model = initialize_llm(
+                fallback_model["provider"], fallback_model["model"]
+            )
+            binded_model = simple_model.bind_tools(
+                [self.current_tool_to_bind],
+                tool_choice=self.current_failing_tool_name,
+            )
+            retry_model = binded_model.with_retry(
+                stop_after_attempt=RETRY_FALLBACK_COUNT
+            )
+            response = retry_model.invoke(self.current_failing_tool_name)
+            cpm(f"response={response}")
+            self.tool_failure_used_fallbacks.add(fallback_model["model"])
+            tool_call = self.base_message_to_tool_call_dict(response)
+            if tool_call:
+                result = self.invoke_prompt_tool_call(tool_call)
+                cpm(f"result={result}")
+                logger.debug(
+                    "Prompt-based fallback executed successfully with model: "
+                    + fallback_model["model"]
+                )
+                self.reset_fallback_handler()
+                return result
+            else:
+                cpm(
+                    response.content if hasattr(response, "content") else response,
+                    title="Fallback Model Response: " + fallback_model["model"],
+                )
+                return response
+        except Exception as e:
+            if isinstance(e, KeyboardInterrupt):
+                raise
+            logger.error(
+                f"Prompt-based fallback with model {fallback_model['model']} failed: {e}"
+            )
+            return None

-    def attempt_fallback_function(self, code: str, logger, agent):
-        """
-        Attempt a function-calling fallback by iterating over fallback models and invoking the provided code.
-
-        This method tries each fallback model (with retry logic configured) until one successfully executes the code.
-
-        Args:
-            code (str): The tool code to invoke via fallback.
-            logger: Logger instance for logging messages.
-            agent: The agent instance to update with the new model upon success.
-
-        Returns:
-            The response from the fallback model invocation.
-
-        Raises:
-            Exception: If all function-calling fallback models fail.
-        """
-        logger.debug("Attempting function-calling fallback using fallback models")
-        failed_tool_call_name = code
-        for fallback_model in self.fallback_tool_models:
-            try:
-                logger.debug(f"Trying fallback model: {fallback_model['model']}")
-                simple_model = initialize_llm(
-                    fallback_model["provider"], fallback_model["model"]
-                )
-                tool_to_bind = self._find_tool_to_bind(agent, failed_tool_call_name)
-                binded_model = simple_model.bind_tools(
-                    [tool_to_bind], tool_choice=failed_tool_call_name
-                )
-                retry_model = binded_model.with_retry(
-                    stop_after_attempt=RETRY_FALLBACK_COUNT
-                )
-                response = retry_model.invoke(code)
-                cpm(f"response={response}")
-                self.tool_failure_used_fallbacks.add(fallback_model["model"])
-                self.reset_fallback_handler()
-                logger.debug(
-                    "Function-calling fallback executed successfully with model: "
-                    + fallback_model["model"]
-                )
-                cpm(
-                    response.content if hasattr(response, "content") else response,
-                    title="Fallback Model Response: " + fallback_model["model"],
-                )
-                return response
-            except Exception as e:
-                if isinstance(e, KeyboardInterrupt):
-                    raise
-                logger.error(
-                    f"Function-calling fallback with model {fallback_model['model']} failed: {e}"
-                )
-        raise Exception("All function-calling fallback models failed")
+    def attempt_fallback_function(self, fallback_model):
+        """
+        Attempt a function-calling fallback by invoking the current failing tool with the given fallback model.
+
+        Args:
+            fallback_model (dict): The fallback model to use.
+
+        Returns:
+            The response from the fallback model invocation, or None if failed.
+        """
+        try:
+            logger.debug(f"Trying fallback model: {fallback_model['model']}")
+            simple_model = initialize_llm(
+                fallback_model["provider"], fallback_model["model"]
+            )
+            binded_model = simple_model.bind_tools(
+                [self.current_tool_to_bind],
+                tool_choice=self.current_failing_tool_name,
+            )
+            retry_model = binded_model.with_retry(
+                stop_after_attempt=RETRY_FALLBACK_COUNT
+            )
+            response = retry_model.invoke(self.current_failing_tool_name)
+            cpm(f"response={response}")
+            self.tool_failure_used_fallbacks.add(fallback_model["model"])
+            logger.debug(
+                "Function-calling fallback executed successfully with model: "
+                + fallback_model["model"]
+            )
+            cpm(
+                response.content if hasattr(response, "content") else response,
+                title="Fallback Model Response: " + fallback_model["model"],
+            )
+            return response
+        except Exception as e:
+            if isinstance(e, KeyboardInterrupt):
+                raise
+            logger.error(
+                f"Function-calling fallback with model {fallback_model['model']} failed: {e}"
+            )
+            return None

     def invoke_prompt_tool_call(self, tool_call_request: dict):
         """
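Both fallback paths share the same core pattern: initialize the fallback model, force a call to the failing tool via `bind_tools(..., tool_choice=...)`, and wrap it with retries. A condensed sketch of that shared pattern (the helper name `invoke_fallback` is hypothetical):

```python
# Condensed, hypothetical sketch of the shared fallback pattern; error handling omitted.
from ra_aid.config import RETRY_FALLBACK_COUNT
from ra_aid.llm import initialize_llm


def invoke_fallback(fallback_model: dict, failing_tool, failing_tool_name: str):
    model = initialize_llm(fallback_model["provider"], fallback_model["model"])
    # Force the fallback model to call exactly the tool that kept failing.
    bound = model.bind_tools([failing_tool], tool_choice=failing_tool_name)
    retrying = bound.with_retry(stop_after_attempt=RETRY_FALLBACK_COUNT)
    # The handler currently passes just the tool name as the prompt.
    return retrying.invoke(failing_tool_name)
```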
@@ -1,9 +1,10 @@
 import logging
 import sys
 from typing import Optional

 from rich.console import Console
-from rich.panel import Panel
 from rich.markdown import Markdown
+from rich.panel import Panel


 class PrettyHandler(logging.Handler):
@@ -1,3 +1,5 @@
+from langchain_core.tools import BaseTool
+
 from ra_aid.tools import (
     ask_expert,
     ask_human,
@@ -61,11 +63,13 @@ def get_read_only_tools(
     return tools


 def get_all_tools_simple():
     """Return a list containing all available tools using existing group methods."""
     return get_all_tools()

-def get_all_tools():
+
+def get_all_tools() -> list[BaseTool]:
     """Return a list containing all available tools from different groups."""
     all_tools = []
     all_tools.extend(get_read_only_tools())
@@ -176,7 +180,7 @@ def get_implementation_tools(
     return tools


-def get_web_research_tools(expert_enabled: bool = True) -> list:
+def get_web_research_tools(expert_enabled: bool = True):
     """Get the list of tools available for web research.

     Args:
@@ -196,9 +200,7 @@ def get_web_research_tools(expert_enabled: bool = True) -> list:
     return tools


-def get_chat_tools(
-    expert_enabled: bool = True, web_research_enabled: bool = False
-) -> list:
+def get_chat_tools(expert_enabled: bool = True, web_research_enabled: bool = False):
     """Get the list of tools available in chat mode.

     Chat mode includes research and implementation capabilities but excludes
@@ -276,11 +276,19 @@ def test_get_model_token_limit_planner(mock_memory):
     token_limit = get_model_token_limit(config, "planner")
     assert token_limit == 120000


 # New tests for private helper methods in agent_utils.py


 def test_setup_and_restore_interrupt_handling():
-    import signal, threading
-    from ra_aid.agent_utils import _setup_interrupt_handling, _restore_interrupt_handling, _request_interrupt
+    import signal
+
+    from ra_aid.agent_utils import (
+        _request_interrupt,
+        _restore_interrupt_handling,
+        _setup_interrupt_handling,
+    )
+
     original_handler = signal.getsignal(signal.SIGINT)
     handler = _setup_interrupt_handling()
     # Verify the SIGINT handler is set to _request_interrupt
@@ -289,66 +297,93 @@ def test_setup_and_restore_interrupt_handling():
     # Verify the SIGINT handler is restored to the original
     assert signal.getsignal(signal.SIGINT) == original_handler


 def test_increment_and_decrement_agent_depth():
-    from ra_aid.agent_utils import _increment_agent_depth, _decrement_agent_depth, _global_memory
+    from ra_aid.agent_utils import (
+        _decrement_agent_depth,
+        _global_memory,
+        _increment_agent_depth,
+    )
+
     _global_memory["agent_depth"] = 10
     _increment_agent_depth()
     assert _global_memory["agent_depth"] == 11
     _decrement_agent_depth()
     assert _global_memory["agent_depth"] == 10


 def test_run_agent_stream(monkeypatch):
-    from ra_aid.agent_utils import _run_agent_stream, _global_memory
+    from ra_aid.agent_utils import _global_memory, _run_agent_stream

     # Create a dummy agent that yields one chunk
     class DummyAgent:
         def stream(self, msg, cfg):
             yield {"content": "chunk1"}

     dummy_agent = DummyAgent()
     # Set flags so that _run_agent_stream will reset them
     _global_memory["plan_completed"] = True
     _global_memory["task_completed"] = True
     _global_memory["completion_message"] = "existing"
     call_flag = {"called": False}

     def fake_print_agent_output(chunk):
         call_flag["called"] = True
-    monkeypatch.setattr("ra_aid.agent_utils.print_agent_output", fake_print_agent_output)
+
+    monkeypatch.setattr(
+        "ra_aid.agent_utils.print_agent_output", fake_print_agent_output
+    )
     _run_agent_stream(dummy_agent, "dummy prompt", {})
     assert call_flag["called"]
     assert _global_memory["plan_completed"] is False
     assert _global_memory["task_completed"] is False
     assert _global_memory["completion_message"] == ""


 def test_execute_test_command_wrapper(monkeypatch):
     from ra_aid.agent_utils import _execute_test_command_wrapper

     # Patch execute_test_command to return a testable tuple
     def fake_execute(config, orig, tests, auto):
         return (True, "new prompt", auto, tests + 1)

     monkeypatch.setattr("ra_aid.agent_utils.execute_test_command", fake_execute)
     result = _execute_test_command_wrapper("orig", {}, 0, False)
     assert result == (True, "new prompt", False, 1)


 def test_handle_api_error_valueerror():
-    from ra_aid.agent_utils import _handle_api_error
     import pytest

+    from ra_aid.agent_utils import _handle_api_error
+
     # ValueError not containing "code" or "429" should be re-raised
     with pytest.raises(ValueError):
         _handle_api_error(ValueError("some error"), 0, 5, 1)


 def test_handle_api_error_max_retries():
-    from ra_aid.agent_utils import _handle_api_error
     import pytest

+    from ra_aid.agent_utils import _handle_api_error
+
     # When attempt reaches max retries, a RuntimeError should be raised
     with pytest.raises(RuntimeError):
         _handle_api_error(Exception("error code 429"), 4, 5, 1)


 def test_handle_api_error_retry(monkeypatch):
-    from ra_aid.agent_utils import _handle_api_error
     import time

+    from ra_aid.agent_utils import _handle_api_error
+
     # Patch time.monotonic and time.sleep to simulate immediate delay expiration
     fake_time = [0]

     def fake_monotonic():
         fake_time[0] += 0.5
         return fake_time[0]

     monkeypatch.setattr(time, "monotonic", fake_monotonic)
     monkeypatch.setattr(time, "sleep", lambda s: None)
     # Should not raise error when attempt is lower than max retries
@@ -1,28 +1,41 @@
 import unittest

 from ra_aid.fallback_handler import FallbackHandler


 class DummyLogger:
     def debug(self, msg):
         pass

     def error(self, msg):
         pass


 class DummyAgent:
     provider = "openai"
     tools = []
     model = None


 class TestFallbackHandler(unittest.TestCase):
     def setUp(self):
-        self.config = {"max_tool_failures": 2, "fallback_tool_models": "dummy-fallback-model"}
+        self.config = {
+            "max_tool_failures": 2,
+            "fallback_tool_models": "dummy-fallback-model",
+        }
         self.fallback_handler = FallbackHandler(self.config)
         self.logger = DummyLogger()
         self.agent = DummyAgent()

     def test_handle_failure_increments_counter(self):
         initial_failures = self.fallback_handler.tool_failure_consecutive_failures
-        self.fallback_handler.handle_failure("dummy_call()", Exception("Test error"), self.logger, self.agent)
-        self.assertEqual(self.fallback_handler.tool_failure_consecutive_failures, initial_failures + 1)
+        self.fallback_handler.handle_failure(
+            "dummy_call()", Exception("Test error"), self.logger, self.agent
+        )
+        self.assertEqual(
+            self.fallback_handler.tool_failure_consecutive_failures,
+            initial_failures + 1,
+        )

     def test_attempt_fallback_resets_counter(self):
         # Monkey-patch dummy functions for fallback components
@@ -30,6 +43,7 @@ class TestFallbackHandler(unittest.TestCase):
            class DummyModel:
                def bind_tools(self, tools, tool_choice):
                    pass
+
            return DummyModel()

        def dummy_merge_chat_history():
@@ -39,6 +53,7 @@ class TestFallbackHandler(unittest.TestCase):
             return True

         import ra_aid.llm as llm
+
         original_initialize = llm.initialize_llm
         original_merge = llm.merge_chat_history
         original_validate = llm.validate_provider_env
@@ -47,12 +62,15 @@ class TestFallbackHandler(unittest.TestCase):
         llm.validate_provider_env = dummy_validate_provider_env

         self.fallback_handler.tool_failure_consecutive_failures = 2
-        self.fallback_handler.attempt_fallback("dummy_tool_call()", self.logger, self.agent)
+        self.fallback_handler.attempt_fallback(
+            "dummy_tool_call()", self.logger, self.agent
+        )
         self.assertEqual(self.fallback_handler.tool_failure_consecutive_failures, 0)

         llm.initialize_llm = original_initialize
         llm.merge_chat_history = original_merge
         llm.validate_provider_env = original_validate


 if __name__ == "__main__":
     unittest.main()