feat(agent_utils.py): convert fallback response to string for prompt concatenation to ensure proper formatting

refactor(fallback_handler.py): change failed_messages from set to list for ordered message handling
refactor(fallback_handler.py): update handle_failure method to accept ToolExecutionError type for better type safety
refactor(fallback_handler.py): implement _reset_on_new_failure method to encapsulate failure reset logic
feat(fallback_handler.py): add construct_prompt_msg_list method to create structured message list for fallback tool calls
This commit is contained in:
Ariel Frischer 2025-02-12 13:39:25 -08:00
parent af9f95ceb1
commit 803acc6166
2 changed files with 50 additions and 13 deletions

View File

@ -878,7 +878,7 @@ def run_agent_with_retry(
except ToolExecutionError as e:
fallback_response = fallback_handler.handle_failure(e, agent)
if fallback_response:
prompt = original_prompt + "\n" + fallback_response
prompt = original_prompt + "\n" + str(fallback_response)
continue
except (KeyboardInterrupt, AgentInterrupt):
raise

View File

@ -4,6 +4,7 @@ import re
from langchain_core.tools import BaseTool
from langgraph.graph.graph import CompiledGraph
from langgraph.graph.message import BaseMessage
from langchain_core.messages import SystemMessage, HumanMessage
from ra_aid.agents.ciayn_agent import CiaynAgent
from ra_aid.config import (
@ -45,7 +46,7 @@ class FallbackHandler:
self.fallback_tool_models = self._load_fallback_tool_models(config)
self.max_failures = config.get("max_tool_failures", DEFAULT_MAX_TOOL_FAILURES)
self.tool_failure_consecutive_failures = 0
self.failed_messages = set()
self.failed_messages: list[BaseMessage] = []
self.tool_failure_used_fallbacks = set()
self.current_failing_tool_name = ""
self.current_tool_to_bind = None
@ -92,7 +93,9 @@ class FallbackHandler:
cpm(message, title="Fallback Models")
return final_models
def handle_failure(self, error: Exception, agent: CiaynAgent | CompiledGraph):
def handle_failure(
self, error: ToolExecutionError, agent: CiaynAgent | CompiledGraph
):
"""
Handle a tool failure by incrementing the failure counter and triggering fallback if thresholds are exceeded.
@ -104,14 +107,7 @@ class FallbackHandler:
return None
failed_tool_call_name = self.extract_failed_tool_name(error)
if (
self.current_failing_tool_name
and failed_tool_call_name != self.current_failing_tool_name
):
logger.debug(
"New failing tool name identified. Resetting consecutive tool failures."
)
self.reset_fallback_handler()
self._reset_on_new_failure(failed_tool_call_name)
logger.debug(
f"_handle_tool_failure: tool failure encountered for code '{failed_tool_call_name}' with error: {error}"
@ -123,7 +119,7 @@ class FallbackHandler:
)
if hasattr(error, "base_message") and error.base_message:
self.failed_messages.add(str(error.base_message))
self.failed_messages.append(error.base_message)
self.tool_failure_consecutive_failures += 1
logger.debug(
@ -172,6 +168,16 @@ class FallbackHandler:
self.tool_failure_used_fallbacks.clear()
self.fallback_tool_models = self._load_fallback_tool_models(self.config)
def _reset_on_new_failure(self, failed_tool_call_name):
if (
self.current_failing_tool_name
and failed_tool_call_name != self.current_failing_tool_name
):
logger.debug(
"New failing tool name identified. Resetting consecutive tool failures."
)
self.reset_fallback_handler()
def extract_failed_tool_name(self, error: ToolExecutionError):
if error.tool_name:
failed_tool_call_name = error.tool_name
@ -234,7 +240,11 @@ class FallbackHandler:
retry_model = binded_model.with_retry(
stop_after_attempt=RETRY_FALLBACK_COUNT
)
response = retry_model.invoke(self.current_failing_tool_name)
# msg_list = []
msg_list = self.construct_prompt_msg_list()
# response = retry_model.invoke(self.current_failing_tool_name)
response = retry_model.invoke(msg_list)
cpm(f"response={response}")
self.tool_failure_used_fallbacks.add(fallback_model["model"])
tool_call = self.base_message_to_tool_call_dict(response)
@ -303,6 +313,33 @@ class FallbackHandler:
)
return None
def construct_prompt_msg_list(self):
"""
Construct a list of chat messages for the fallback prompt.
The initial message instructs the assistant that it is a fallback tool caller.
Then includes the failed tool call messages from self.failed_messages.
Finally, it appends a human message asking it to retry calling the tool with correct valid arguments.
Returns:
list: A list of chat messages.
"""
msg_list: list[BaseMessage] = []
msg_list.append(
SystemMessage(
content="You are a fallback tool caller. Your only responsibility is to figure out what the previous failed tool call was trying to do and to call that tool with the correct format and arguments, using the provided failure messages."
)
)
if self.failed_messages:
# Convert to system messages to avoid API errors asking for correct msg structure
msg_list.extend([SystemMessage(str(msg)) for msg in self.failed_messages])
msg_list.append(
HumanMessage(
content=f"Retry using the tool '{self.current_failing_tool_name}' with improved arguments."
)
)
return msg_list
def invoke_prompt_tool_call(self, tool_call_request: dict):
"""
Invoke a tool call from a prompt-based fallback response.