From 67ecf72a6c15c89f4ac4199d5a8710d80c1eb10b Mon Sep 17 00:00:00 2001 From: Ariel Frischer Date: Tue, 11 Feb 2025 18:35:34 -0800 Subject: [PATCH] feat(fallback): implement fallback handler for tool execution errors to enhance error resilience and user experience refactor(fallback): streamline fallback model selection and invocation process for improved maintainability fix(config): reduce maximum tool failures from 3 to 2 to tighten error handling thresholds style(console): improve error message formatting and logging for better clarity and debugging chore(main): remove redundant fallback tool model handling from main function to simplify configuration management --- ra_aid/__main__.py | 18 --- ra_aid/agent_utils.py | 69 +++++++--- ra_aid/config.py | 2 +- ra_aid/console/output.py | 22 +++- ra_aid/exceptions.py | 5 +- ra_aid/fallback_handler.py | 261 +++++++++++++++++++++++-------------- ra_aid/tool_configs.py | 19 ++- 7 files changed, 260 insertions(+), 136 deletions(-) diff --git a/ra_aid/__main__.py b/ra_aid/__main__.py index 654dd60..e027a08 100644 --- a/ra_aid/__main__.py +++ b/ra_aid/__main__.py @@ -427,15 +427,6 @@ def main(): _global_memory["config"]["planner_model"] = args.planner_model or args.model _global_memory["config"]["no_fallback_tool"] = args.no_fallback_tool - _global_memory["config"]["fallback_tool_models"] = ( - [ - model.strip() - for model in args.fallback_tool_models.split(",") - if model.strip() - ] - if args.fallback_tool_models - else [] - ) # Store research config with fallback to base values _global_memory["config"]["research_provider"] = ( @@ -445,15 +436,6 @@ def main(): # Store fallback tool configuration _global_memory["config"]["no_fallback_tool"] = args.no_fallback_tool - _global_memory["config"]["fallback_tool_models"] = ( - [ - model.strip() - for model in args.fallback_tool_models.split(",") - if model.strip() - ] - if args.fallback_tool_models - else [] - ) # Run research stage print_stage_header("Research Stage") diff --git a/ra_aid/agent_utils.py b/ra_aid/agent_utils.py index 678ce73..cf85a0c 100644 --- a/ra_aid/agent_utils.py +++ b/ra_aid/agent_utils.py @@ -16,7 +16,6 @@ from langchain_core.language_models import BaseChatModel from langchain_core.messages import ( BaseMessage, HumanMessage, - InvalidToolCall, trim_messages, ) from langchain_core.tools import tool @@ -339,9 +338,6 @@ def run_research_agent( if memory is None: memory = MemorySaver() - if thread_id is None: - thread_id = str(uuid.uuid4()) - tools = get_research_tools( research_only=research_only, expert_enabled=expert_enabled, @@ -413,7 +409,8 @@ def run_research_agent( if agent is not None: logger.debug("Research agent completed successfully") - _result = run_agent_with_retry(agent, prompt, run_config) + fallback_handler = FallbackHandler(config, tools) + _result = run_agent_with_retry(agent, prompt, run_config, fallback_handler) if _result: # Log research completion log_work_event(f"Completed research phase for: {base_task_or_query}") @@ -529,7 +526,8 @@ def run_web_research_agent( console.print(Panel(Markdown(console_message), title="🔬 Researching...")) logger.debug("Web research agent completed successfully") - _result = run_agent_with_retry(agent, prompt, run_config) + fallback_handler = FallbackHandler(config, tools) + _result = run_agent_with_retry(agent, prompt, run_config, fallback_handler) if _result: # Log web research completion log_work_event(f"Completed web research phase for: {query}") @@ -634,7 +632,10 @@ def run_planning_agent( try: print_stage_header("Planning Stage") logger.debug("Planning agent completed successfully") - _result = run_agent_with_retry(agent, planning_prompt, run_config) + fallback_handler = FallbackHandler(config, tools) + _result = run_agent_with_retry( + agent, planning_prompt, run_config, fallback_handler + ) if _result: # Log planning completion log_work_event(f"Completed planning phase for: {base_task}") @@ -739,7 +740,8 @@ def run_task_implementation_agent( try: logger.debug("Implementation agent completed successfully") - _result = run_agent_with_retry(agent, prompt, run_config) + fallback_handler = FallbackHandler(config, tools) + _result = run_agent_with_retry(agent, prompt, run_config, fallback_handler) if _result: # Log task implementation completion log_work_event(f"Completed implementation of task: {task}") @@ -805,7 +807,7 @@ def _decrement_agent_depth(): _global_memory["agent_depth"] = _global_memory.get("agent_depth", 1) - 1 -def _run_agent_stream(agent, prompt, config): +def _run_agent_stream(agent: CompiledGraph, prompt: str, config: dict): for chunk in agent.stream({"messages": [HumanMessage(content=prompt)]}, config): logger.debug("Agent output: %s", chunk) check_interrupt() @@ -840,7 +842,9 @@ def _handle_api_error(e, attempt, max_retries, base_delay): time.sleep(0.1) -def run_agent_with_retry(agent, prompt: str, config: dict) -> Optional[str]: +def run_agent_with_retry( + agent, prompt: str, config: dict, fallback_handler: FallbackHandler +) -> Optional[str]: """Run an agent with retry logic for API errors.""" logger.debug("Running agent with prompt length: %d", len(prompt)) original_handler = _setup_interrupt_handling() @@ -850,7 +854,6 @@ def run_agent_with_retry(agent, prompt: str, config: dict) -> Optional[str]: _max_test_retries = config.get("max_test_cmd_retries", DEFAULT_MAX_TEST_CMD_RETRIES) auto_test = config.get("auto_test", False) original_prompt = prompt - fallback_handler = FallbackHandler(config) with InterruptibleSection(): try: @@ -872,8 +875,13 @@ def run_agent_with_retry(agent, prompt: str, config: dict) -> Optional[str]: continue logger.debug("Agent run completed successfully") return "Agent run completed successfully" - except (ToolExecutionError, InvalidToolCall) as e: - _handle_tool_execution_error(fallback_handler, agent, e) + except ToolExecutionError as e: + fallback_response = _handle_tool_execution_error( + fallback_handler, agent, e + ) + if fallback_response: + prompt = original_prompt + "\n" + fallback_response + continue except (KeyboardInterrupt, AgentInterrupt): raise except ( @@ -892,6 +900,37 @@ def run_agent_with_retry(agent, prompt: str, config: dict) -> Optional[str]: def _handle_tool_execution_error( fallback_handler: FallbackHandler, agent: CiaynAgent | CompiledGraph, - error: ToolExecutionError | InvalidToolCall, + error: ToolExecutionError, ): - fallback_handler.handle_failure("Tool execution error", error, agent) + logger.debug("Entering _handle_tool_execution_error with error: %s", error) + if error.tool_name: + failed_tool_call_name = error.tool_name + logger.debug( + "Extracted failed_tool_call_name from error.tool_name: %s", + failed_tool_call_name, + ) + else: + import re + + msg = str(error) + logger.debug("Error message: %s", msg) + match = re.search(r"name=['\"](\w+)['\"]", msg) + if match: + failed_tool_call_name = match.group(1) + logger.debug( + "Extracted failed_tool_call_name using regex: %s", failed_tool_call_name + ) + else: + failed_tool_call_name = "Tool execution error" + logger.debug( + "Defaulting failed_tool_call_name to: %s", failed_tool_call_name + ) + logger.debug( + "Calling fallback_handler.handle_failure with failed_tool_call_name: %s", + failed_tool_call_name, + ) + fallback_response = fallback_handler.handle_failure( + failed_tool_call_name, error, agent + ) + logger.debug("Fallback response received: %s", fallback_response) + return fallback_response diff --git a/ra_aid/config.py b/ra_aid/config.py index e85cb12..54d7995 100644 --- a/ra_aid/config.py +++ b/ra_aid/config.py @@ -2,7 +2,7 @@ DEFAULT_RECURSION_LIMIT = 100 DEFAULT_MAX_TEST_CMD_RETRIES = 3 -DEFAULT_MAX_TOOL_FAILURES = 3 +DEFAULT_MAX_TOOL_FAILURES = 2 FALLBACK_TOOL_MODEL_LIMIT = 5 RETRY_FALLBACK_COUNT = 3 RETRY_FALLBACK_DELAY = 2 diff --git a/ra_aid/console/output.py b/ra_aid/console/output.py index 8b64142..aad96e6 100644 --- a/ra_aid/console/output.py +++ b/ra_aid/console/output.py @@ -1,9 +1,11 @@ -from typing import Any, Dict +from typing import Any, Dict, Optional from langchain_core.messages import AIMessage from rich.markdown import Markdown from rich.panel import Panel +from ra_aid.exceptions import ToolExecutionError + # Import shared console instance from .formatting import console @@ -33,10 +35,26 @@ def print_agent_output(chunk: Dict[str, Any]) -> None: elif "tools" in chunk and "messages" in chunk["tools"]: for msg in chunk["tools"]["messages"]: if msg.status == "error" and msg.content: + err_msg = msg.content.strip() console.print( Panel( - Markdown(msg.content.strip()), + Markdown(err_msg), title="❌ Tool Error", border_style="red bold", ) ) + tool_name = getattr(msg, "name", None) + raise ToolExecutionError(err_msg, tool_name=tool_name) + + +def cpm(message: str, title: Optional[str] = None, border_style: str = "blue") -> None: + """ + Print a message using a Panel with Markdown formatting. + + Args: + message (str): The message content to display. + title (Optional[str]): An optional title for the panel. + border_style (str): Border style for the panel. + """ + + console.print(Panel(Markdown(message), title=title, border_style=border_style)) diff --git a/ra_aid/exceptions.py b/ra_aid/exceptions.py index 696b47e..d8bc532 100644 --- a/ra_aid/exceptions.py +++ b/ra_aid/exceptions.py @@ -17,5 +17,6 @@ class ToolExecutionError(Exception): This exception is used to distinguish tool execution failures from other types of errors in the agent system. """ - - pass + def __init__(self, message, tool_name=None): + super().__init__(message) + self.tool_name = tool_name diff --git a/ra_aid/fallback_handler.py b/ra_aid/fallback_handler.py index 2b248a5..8dd459c 100644 --- a/ra_aid/fallback_handler.py +++ b/ra_aid/fallback_handler.py @@ -1,13 +1,24 @@ +from typing import Dict +from langchain_core.tools import BaseTool +from langgraph.graph.graph import CompiledGraph +from langgraph.graph.message import BaseMessage + +from ra_aid.console.output import cpm +import json + +from ra_aid.agents.ciayn_agent import CiaynAgent from ra_aid.config import ( DEFAULT_MAX_TOOL_FAILURES, FALLBACK_TOOL_MODEL_LIMIT, RETRY_FALLBACK_COUNT, ) +from ra_aid.logging_config import get_logger from ra_aid.tool_leaderboard import supported_top_tool_models from rich.console import Console -from rich.markdown import Markdown -from rich.panel import Panel -from ra_aid.llm import initialize_llm, merge_chat_history, validate_provider_env +from ra_aid.llm import initialize_llm, validate_provider_env + +# from langgraph.graph.message import BaseMessage, BaseMessageChunk +# from langgraph.prebuilt import ToolNode logger = get_logger(__name__) @@ -22,18 +33,21 @@ class FallbackHandler: counters when a tool call succeeds. """ - def __init__(self, config): + def __init__(self, config, tools): """ - Initialize the FallbackHandler with the given configuration. + Initialize the FallbackHandler with the given configuration and tools. Args: config (dict): Configuration dictionary that may include fallback settings. + tools (list): List of available tools. """ self.config = config + self.tools: list[BaseTool] = tools self.fallback_enabled = config.get("fallback_tool_enabled", True) self.fallback_tool_models = self._load_fallback_tool_models(config) self.tool_failure_consecutive_failures = 0 self.tool_failure_used_fallbacks = set() + self.console = Console() def _load_fallback_tool_models(self, config): """ @@ -49,46 +63,37 @@ class FallbackHandler: Returns: list of dict: Each dictionary contains keys 'model' and 'type' representing a fallback model. """ - fallback_tool_models_config = config.get("fallback_tool_models") - if fallback_tool_models_config: - # Assume comma-separated model names; wrap each in a dict with default type "prompt" - models = [] - for m in [ - x.strip() for x in fallback_tool_models_config.split(",") if x.strip() - ]: - models.append({"model": m, "type": "prompt"}) - return models - else: - console = Console() - supported = [] - skipped = [] - for item in supported_top_tool_models: - provider = item.get("provider") - model_name = item.get("model") - if validate_provider_env(provider): - supported.append(item) - if len(supported) == FALLBACK_TOOL_MODEL_LIMIT: - break - else: - skipped.append(model_name) - final_models = [] - for item in supported: - if "type" not in item: - item["type"] = "prompt" - item["model"] = item["model"].lower() - final_models.append(item) - message = "Fallback models selected: " + ", ".join( - [m["model"] for m in final_models] + supported = [] + skipped = [] + for item in supported_top_tool_models: + provider = item.get("provider") + model_name = item.get("model") + if validate_provider_env(provider): + supported.append(item) + if len(supported) == FALLBACK_TOOL_MODEL_LIMIT: + break + else: + skipped.append(model_name) + final_models = [] + for item in supported: + if "type" not in item: + item["type"] = "prompt" + item["model"] = item["model"].lower() + final_models.append(item) + message = "Fallback models selected: " + ", ".join( + [m["model"] for m in final_models] + ) + if skipped: + message += ( + "\nSkipped top tool calling models due to missing provider ENV API keys: " + + ", ".join(skipped) ) - if skipped: - message += ( - "\nSkipped top tool calling models due to missing provider ENV API keys: " - + ", ".join(skipped) - ) - console.print(Panel(Markdown(message), title="Fallback Models")) - return final_models + cpm(message, title="Fallback Models") + return final_models - def handle_failure(self, code: str, error: Exception, agent): + def handle_failure( + self, code: str, error: Exception, agent: CiaynAgent | CompiledGraph + ): """ Handle a tool failure by incrementing the failure counter and triggering fallback if thresholds are exceeded. @@ -114,7 +119,7 @@ class FallbackHandler: logger.debug( "_handle_tool_failure: threshold reached, invoking fallback mechanism." ) - self.attempt_fallback(code, logger, agent) + return self.attempt_fallback(code, logger, agent) def attempt_fallback(self, code: str, logger, agent): """ @@ -127,17 +132,13 @@ class FallbackHandler: """ logger.debug(f"_attempt_fallback: initiating fallback for code: {code}") fallback_model = self.fallback_tool_models[0] - failed_tool_call_name = code.split("(")[0].strip() + failed_tool_call_name = code logger.error( f"Tool call failed {self.tool_failure_consecutive_failures} times. Attempting fallback to model: {fallback_model['model']} for tool: {failed_tool_call_name}" ) - Console().print( - Panel( - Markdown( - f"**Tool fallback activated**: Switching to fallback model {fallback_model['model']} for tool {failed_tool_call_name}." - ), - title="Fallback Notification", - ) + cpm( + f"**Tool fallback activated**: Switching to fallback model {fallback_model['model']} for tool {failed_tool_call_name}.", + title="Fallback Notification", ) if fallback_model.get("type", "prompt").lower() == "fc": self.attempt_fallback_function(code, logger, agent) @@ -151,6 +152,30 @@ class FallbackHandler: self.tool_failure_consecutive_failures = 0 self.tool_failure_used_fallbacks.clear() + def _find_tool_to_bind(self, agent, failed_tool_call_name): + logger.debug(f"failed_tool_call_name={failed_tool_call_name}") + tool_to_bind = None + if hasattr(agent, "tools"): + tool_to_bind = next( + (t for t in agent.tools if t.func.__name__ == failed_tool_call_name), + None, + ) + if tool_to_bind is None: + from ra_aid.tool_configs import get_all_tools + + all_tools = get_all_tools() + tool_to_bind = next( + (t for t in all_tools if t.func.__name__ == failed_tool_call_name), + None, + ) + if tool_to_bind is None: + available = [t.func.__name__ for t in get_all_tools()] + logger.debug( + f"Failed to find tool: {failed_tool_call_name}. Available tools: {available}" + ) + raise Exception(f"Tool {failed_tool_call_name} not found in all tools.") + return tool_to_bind + def attempt_fallback_prompt(self, code: str, logger, agent): """ Attempt a prompt-based fallback by iterating over fallback models and invoking the provided code. @@ -169,43 +194,41 @@ class FallbackHandler: Exception: If all prompt-based fallback models fail. """ logger.debug("Attempting prompt-based fallback using fallback models") - failed_tool_call_name = code.split("(")[0].strip() + failed_tool_call_name = code for fallback_model in self.fallback_tool_models: try: logger.debug(f"Trying fallback model: {fallback_model['model']}") simple_model = initialize_llm( fallback_model["provider"], fallback_model["model"] ) - tool_to_bind = next( - ( - t - for t in agent.tools - if t.func.__name__ == failed_tool_call_name - ), - None, - ) - if tool_to_bind is None: - logger.debug( - f"Failed to find tool: {failed_tool_call_name}. Available tools: {[t.func.__name__ for t in agent.tools]}" - ) - raise Exception( - f"Tool {failed_tool_call_name} not found in agent.tools" - ) + tool_to_bind = self._find_tool_to_bind(agent, failed_tool_call_name) binded_model = simple_model.bind_tools( [tool_to_bind], tool_choice=failed_tool_call_name ) - retry_model = binded_model.with_retry( - stop_after_attempt=RETRY_FALLBACK_COUNT - ) - response = retry_model.invoke(code) + # retry_model = binded_model.with_retry( + # stop_after_attempt=RETRY_FALLBACK_COUNT + # ) + response = binded_model.invoke(code) + cpm(f"response={response}") + self.tool_failure_used_fallbacks.add(fallback_model["model"]) - agent.model = retry_model - self.reset_fallback_handler() - logger.debug( - "Prompt-based fallback executed successfully with model: " - + fallback_model["model"] - ) - return response + + tool_call = self.base_message_to_tool_call_dict(response) + if tool_call: + result = self.invoke_prompt_tool_call(tool_call) + cpm(f"result={result}") + logger.debug( + "Prompt-based fallback executed successfully with model: " + + fallback_model["model"] + ) + self.reset_fallback_handler() + return result + else: + cpm( + response.content if hasattr(response, "content") else response, + title="Fallback Model Response: " + fallback_model["model"], + ) + return response except Exception as e: if isinstance(e, KeyboardInterrupt): raise @@ -232,28 +255,14 @@ class FallbackHandler: Exception: If all function-calling fallback models fail. """ logger.debug("Attempting function-calling fallback using fallback models") - failed_tool_call_name = code.split("(")[0].strip() + failed_tool_call_name = code for fallback_model in self.fallback_tool_models: try: logger.debug(f"Trying fallback model: {fallback_model['model']}") simple_model = initialize_llm( fallback_model["provider"], fallback_model["model"] ) - tool_to_bind = next( - ( - t - for t in agent.tools - if t.func.__name__ == failed_tool_call_name - ), - None, - ) - if tool_to_bind is None: - logger.debug( - f"Failed to find tool: {failed_tool_call_name}. Available tools: {[t.func.__name__ for t in agent.tools]}" - ) - raise Exception( - f"Tool {failed_tool_call_name} not found in agent.tools" - ) + tool_to_bind = self._find_tool_to_bind(agent, failed_tool_call_name) binded_model = simple_model.bind_tools( [tool_to_bind], tool_choice=failed_tool_call_name ) @@ -261,13 +270,18 @@ class FallbackHandler: stop_after_attempt=RETRY_FALLBACK_COUNT ) response = retry_model.invoke(code) + cpm(f"response={response}") self.tool_failure_used_fallbacks.add(fallback_model["model"]) - agent.model = retry_model self.reset_fallback_handler() logger.debug( "Function-calling fallback executed successfully with model: " + fallback_model["model"] ) + + cpm( + response.content if hasattr(response, "content") else response, + title="Fallback Model Response: " + fallback_model["model"], + ) return response except Exception as e: if isinstance(e, KeyboardInterrupt): @@ -276,3 +290,58 @@ class FallbackHandler: f"Function-calling fallback with model {fallback_model['model']} failed: {e}" ) raise Exception("All function-calling fallback models failed") + + def invoke_prompt_tool_call(self, tool_call_request: dict): + """ + Invoke a tool call from a prompt-based fallback response. + + Args: + tool_call_request (dict): The tool call request containing keys 'type', 'name', and 'arguments'. + + Returns: + The result of invoking the tool. + """ + tool_name_to_tool = {tool.func.__name__: tool for tool in self.tools} + name = tool_call_request["name"] + arguments = tool_call_request["arguments"] + # return tool_name_to_tool[name].invoke(arguments) + # tool_call_dict = {"arguments": arguments} + return tool_name_to_tool[name].invoke(arguments) + + def base_message_to_tool_call_dict(self, response: BaseMessage): + """ + Extracts a tool call dictionary from a fallback response. + + Args: + response: The response object containing tool call data. + + Returns: + A tool call dictionary with keys 'id', 'type', 'name', and 'arguments' if a tool call is found, + otherwise None. + """ + tool_calls = None + if hasattr(response, "additional_kwargs") and response.additional_kwargs.get( + "tool_calls" + ): + tool_calls = response.additional_kwargs.get("tool_calls") + elif hasattr(response, "tool_calls"): + tool_calls = response.tool_calls + elif isinstance(response, dict) and response.get("additional_kwargs", {}).get( + "tool_calls" + ): + tool_calls = response.get("additional_kwargs").get("tool_calls") + if tool_calls: + if len(tool_calls) > 1: + logger.warning("Multiple tool calls detected, using the first one") + tool_call = tool_calls[0] + return { + "id": tool_call["id"], + "type": tool_call["type"], + "name": tool_call["function"]["name"], + "arguments": ( + json.loads(tool_call["function"]["arguments"]) + if isinstance(tool_call["function"]["arguments"], str) + else tool_call["function"]["arguments"] + ), + } + return None diff --git a/ra_aid/tool_configs.py b/ra_aid/tool_configs.py index e4042f1..8fce691 100644 --- a/ra_aid/tool_configs.py +++ b/ra_aid/tool_configs.py @@ -28,7 +28,7 @@ from ra_aid.tools.write_file import write_file_tool # Read-only tools that don't modify system state def get_read_only_tools( human_interaction: bool = False, web_research_enabled: bool = False -) -> list: +): """Get the list of read-only tools, optionally including human interaction tools. Args: @@ -61,6 +61,21 @@ def get_read_only_tools( return tools +def get_all_tools_simple(): + """Return a list containing all available tools using existing group methods.""" + return get_all_tools() + +def get_all_tools(): + """Return a list containing all available tools from different groups.""" + all_tools = [] + all_tools.extend(get_read_only_tools()) + all_tools.extend(MODIFICATION_TOOLS) + all_tools.extend(EXPERT_TOOLS) + all_tools.extend(RESEARCH_TOOLS) + all_tools.extend(get_web_research_tools()) + all_tools.extend(get_chat_tools()) + return all_tools + # Define constant tool groups READ_ONLY_TOOLS = get_read_only_tools() @@ -81,7 +96,7 @@ def get_research_tools( expert_enabled: bool = True, human_interaction: bool = False, web_research_enabled: bool = False, -) -> list: +): """Get the list of research tools based on mode and whether expert is enabled. Args: