diff --git a/ra_aid/agent_utils.py b/ra_aid/agent_utils.py index 3aea6a8..ad1adf5 100644 --- a/ra_aid/agent_utils.py +++ b/ra_aid/agent_utils.py @@ -878,7 +878,7 @@ def run_agent_with_retry( except ToolExecutionError as e: fallback_response = fallback_handler.handle_failure(e, agent) if fallback_response: - prompt = original_prompt + "\n" + fallback_response + prompt = original_prompt + "\n" + str(fallback_response) continue except (KeyboardInterrupt, AgentInterrupt): raise diff --git a/ra_aid/fallback_handler.py b/ra_aid/fallback_handler.py index 5e80826..4ae2016 100644 --- a/ra_aid/fallback_handler.py +++ b/ra_aid/fallback_handler.py @@ -4,6 +4,7 @@ import re from langchain_core.tools import BaseTool from langgraph.graph.graph import CompiledGraph from langgraph.graph.message import BaseMessage +from langchain_core.messages import SystemMessage, HumanMessage from ra_aid.agents.ciayn_agent import CiaynAgent from ra_aid.config import ( @@ -45,7 +46,7 @@ class FallbackHandler: self.fallback_tool_models = self._load_fallback_tool_models(config) self.max_failures = config.get("max_tool_failures", DEFAULT_MAX_TOOL_FAILURES) self.tool_failure_consecutive_failures = 0 - self.failed_messages = set() + self.failed_messages: list[BaseMessage] = [] self.tool_failure_used_fallbacks = set() self.current_failing_tool_name = "" self.current_tool_to_bind = None @@ -92,7 +93,9 @@ class FallbackHandler: cpm(message, title="Fallback Models") return final_models - def handle_failure(self, error: Exception, agent: CiaynAgent | CompiledGraph): + def handle_failure( + self, error: ToolExecutionError, agent: CiaynAgent | CompiledGraph + ): """ Handle a tool failure by incrementing the failure counter and triggering fallback if thresholds are exceeded. @@ -104,14 +107,7 @@ class FallbackHandler: return None failed_tool_call_name = self.extract_failed_tool_name(error) - if ( - self.current_failing_tool_name - and failed_tool_call_name != self.current_failing_tool_name - ): - logger.debug( - "New failing tool name identified. Resetting consecutive tool failures." - ) - self.reset_fallback_handler() + self._reset_on_new_failure(failed_tool_call_name) logger.debug( f"_handle_tool_failure: tool failure encountered for code '{failed_tool_call_name}' with error: {error}" @@ -123,7 +119,7 @@ class FallbackHandler: ) if hasattr(error, "base_message") and error.base_message: - self.failed_messages.add(str(error.base_message)) + self.failed_messages.append(error.base_message) self.tool_failure_consecutive_failures += 1 logger.debug( @@ -172,6 +168,16 @@ class FallbackHandler: self.tool_failure_used_fallbacks.clear() self.fallback_tool_models = self._load_fallback_tool_models(self.config) + def _reset_on_new_failure(self, failed_tool_call_name): + if ( + self.current_failing_tool_name + and failed_tool_call_name != self.current_failing_tool_name + ): + logger.debug( + "New failing tool name identified. Resetting consecutive tool failures." + ) + self.reset_fallback_handler() + def extract_failed_tool_name(self, error: ToolExecutionError): if error.tool_name: failed_tool_call_name = error.tool_name @@ -234,7 +240,11 @@ class FallbackHandler: retry_model = binded_model.with_retry( stop_after_attempt=RETRY_FALLBACK_COUNT ) - response = retry_model.invoke(self.current_failing_tool_name) + # msg_list = [] + msg_list = self.construct_prompt_msg_list() + # response = retry_model.invoke(self.current_failing_tool_name) + response = retry_model.invoke(msg_list) + cpm(f"response={response}") self.tool_failure_used_fallbacks.add(fallback_model["model"]) tool_call = self.base_message_to_tool_call_dict(response) @@ -303,6 +313,33 @@ class FallbackHandler: ) return None + def construct_prompt_msg_list(self): + """ + Construct a list of chat messages for the fallback prompt. + The initial message instructs the assistant that it is a fallback tool caller. + Then includes the failed tool call messages from self.failed_messages. + Finally, it appends a human message asking it to retry calling the tool with correct valid arguments. + + Returns: + list: A list of chat messages. + """ + msg_list: list[BaseMessage] = [] + msg_list.append( + SystemMessage( + content="You are a fallback tool caller. Your only responsibility is to figure out what the previous failed tool call was trying to do and to call that tool with the correct format and arguments, using the provided failure messages." + ) + ) + if self.failed_messages: + # Convert to system messages to avoid API errors asking for correct msg structure + msg_list.extend([SystemMessage(str(msg)) for msg in self.failed_messages]) + + msg_list.append( + HumanMessage( + content=f"Retry using the tool '{self.current_failing_tool_name}' with improved arguments." + ) + ) + return msg_list + def invoke_prompt_tool_call(self, tool_call_request: dict): """ Invoke a tool call from a prompt-based fallback response.