From 15ce534f8f9fff818122f8db4e85c6f0ce11247d Mon Sep 17 00:00:00 2001 From: Ariel Frischer Date: Thu, 13 Feb 2025 20:09:21 -0800 Subject: [PATCH] feat(agent_utils.py): add debug print for config to assist in troubleshooting feat(ciayn_agent.py): pass config to CiaynAgent for improved functionality fix(ciayn_agent.py): handle tool execution errors more gracefully with msg_list feat(fallback_handler.py): enhance handle_failure method to utilize msg_list for better context feat(fallback_handler.py): implement init_msg_list to manage message history effectively test(test_fallback_handler.py): add unit tests for init_msg_list to ensure correct behavior --- ra_aid/agent_utils.py | 5 +-- ra_aid/agents/ciayn_agent.py | 20 ++++++++---- ra_aid/fallback_handler.py | 46 +++++++++++++++++++++------ tests/ra_aid/test_fallback_handler.py | 16 +++++++++- 4 files changed, 69 insertions(+), 18 deletions(-) diff --git a/ra_aid/agent_utils.py b/ra_aid/agent_utils.py index 74ea2d7..5f35a9a 100644 --- a/ra_aid/agent_utils.py +++ b/ra_aid/agent_utils.py @@ -270,6 +270,7 @@ def create_agent( """ try: config = _global_memory.get("config", {}) + print(f"config={config}") max_input_tokens = ( get_model_token_limit(config, agent_type) or DEFAULT_TOKEN_LIMIT ) @@ -281,7 +282,7 @@ def create_agent( return create_react_agent(model, tools, **agent_kwargs) else: logger.debug("Using CiaynAgent agent instance") - return CiaynAgent(model, tools, max_tokens=max_input_tokens) + return CiaynAgent(model, tools, max_tokens=max_input_tokens, config=config) except Exception as e: # Default to REACT agent if provider/model detection fails @@ -882,7 +883,7 @@ def _handle_fallback_response( """ if not fallback_handler: return - fallback_response = fallback_handler.handle_failure(error, agent) + fallback_response = fallback_handler.handle_failure(error, agent, msg_list) agent_type = get_agent_type(agent) if fallback_response and agent_type == "React": msg_list_response = [HumanMessage(str(msg)) for msg in fallback_response] diff --git a/ra_aid/agents/ciayn_agent.py b/ra_aid/agents/ciayn_agent.py index de38591..14caba3 100644 --- a/ra_aid/agents/ciayn_agent.py +++ b/ra_aid/agents/ciayn_agent.py @@ -139,18 +139,24 @@ class CiaynAgent: try: code = code.strip() + if code.startswith("```"): + code = code[3:].strip() + if code.endswith("```"): + code = code[:-3].strip() + + raise ToolExecutionError( + "err", base_message=msg, tool_name="ripgrep_search" + ) + logger.debug(f"_execute_tool: stripped code: {code}") # if the eval fails, try to extract it via a model call if validate_function_call_pattern(code): functions_list = "\n\n".join(self.available_functions) code = self._extract_tool_call(code, functions_list) + pass - logger.debug( - f"_execute_tool: evaluating code: {code} with globals: {list(globals_dict.keys())}" - ) result = eval(code.strip(), globals_dict) - logger.debug(f"_execute_tool: result: {result}") return result except Exception as e: error_msg = f"Error: {str(e)} \n Could not excute code: {code}" @@ -230,6 +236,7 @@ class CiaynAgent: raise ToolExecutionError("Failed to extract tool call") ma = matches[0][0].strip() mb = matches[0][1].strip().replace("\n", " ") + logger.debug(f"Extracted tool call: {ma}({mb})") return f"{ma}({mb})" def _trim_chat_history( @@ -284,13 +291,14 @@ class CiaynAgent: response = self.model.invoke([self.sys_message] + full_history) try: - # logger.debug(f"Code generated by agent: {response.content}") last_result = self._execute_tool(response) self.chat_history.append(response) self.fallback_handler.reset_fallback_handler() yield {} except ToolExecutionError as e: - fallback_response = self.fallback_handler.handle_failure(e, self) + fallback_response = self.fallback_handler.handle_failure( + e, self, self.chat_history + ) last_result = self.handle_fallback_response(fallback_response, e) yield {} diff --git a/ra_aid/fallback_handler.py b/ra_aid/fallback_handler.py index 3d61098..63dc742 100644 --- a/ra_aid/fallback_handler.py +++ b/ra_aid/fallback_handler.py @@ -49,6 +49,7 @@ class FallbackHandler: self.failed_messages: list[BaseMessage] = [] self.current_failing_tool_name = "" self.current_tool_to_bind: None | BaseTool = None + self.msg_list: list[BaseMessage] = [] cpm( "Fallback models selected: " @@ -100,7 +101,9 @@ class FallbackHandler: ) return final_models - def handle_failure(self, error: ToolExecutionError, agent: RAgents): + def handle_failure( + self, error: ToolExecutionError, agent: RAgents, msg_list: list[BaseMessage] + ): """ Handle a tool failure by incrementing the failure counter and triggering fallback if thresholds are exceeded. @@ -111,6 +114,9 @@ class FallbackHandler: if not self.fallback_enabled: return None + if self.tool_failure_consecutive_failures == 0: + self.init_msg_list(msg_list) + failed_tool_call_name = self.extract_failed_tool_name(error) self._reset_on_new_failure(failed_tool_call_name) @@ -177,6 +183,7 @@ class FallbackHandler: self.fallback_tool_models = self._load_fallback_tool_models(self.config) self.current_failing_tool_name = "" self.current_tool_to_bind = None + self.msg_list = [] def _reset_on_new_failure(self, failed_tool_call_name): if ( @@ -296,22 +303,29 @@ class FallbackHandler: Returns: list: A list of chat messages. """ - msg_list: list[BaseMessage] = [] - msg_list.append( + prompt_msg_list: list[BaseMessage] = [] + prompt_msg_list.append( SystemMessage( content="You are a fallback tool caller. Your only responsibility is to figure out what the previous failed tool call was trying to do and to call that tool with the correct format and arguments, using the provided failure messages." ) ) + + # TODO: Have some way to use the correct message type in the future, dont just convert everything to system message. + # This may be difficult as each model type may require different chat structures and throw API errors. + prompt_msg_list.extend(SystemMessage(str(msg)) for msg in self.msg_list) + if self.failed_messages: # Convert to system messages to avoid API errors asking for correct msg structure - msg_list.extend([SystemMessage(str(msg)) for msg in self.failed_messages]) + prompt_msg_list.extend( + [SystemMessage(str(msg)) for msg in self.failed_messages] + ) - msg_list.append( + prompt_msg_list.append( HumanMessage( - content=f"Retry using the tool '{self.current_failing_tool_name}' with improved arguments." + content=f"Retry using the tool: '{self.current_failing_tool_name}' with correct arguments and formatting." ) ) - return msg_list + return prompt_msg_list def invoke_prompt_tool_call(self, tool_call_request: dict): """ @@ -323,12 +337,17 @@ class FallbackHandler: Returns: The result of invoking the tool. """ - tool_name_to_tool = {getattr(tool.func, "__name__", None): tool for tool in self.tools} + tool_name_to_tool = { + getattr(tool.func, "__name__", None): tool for tool in self.tools + } name = tool_call_request["name"] arguments = tool_call_request["arguments"] if name in tool_name_to_tool: return tool_name_to_tool[name].invoke(arguments) - elif self.current_tool_to_bind is not None and getattr(self.current_tool_to_bind.func, "__name__", None) == name: + elif ( + self.current_tool_to_bind is not None + and getattr(self.current_tool_to_bind.func, "__name__", None) == name + ): return self.current_tool_to_bind.invoke(arguments) else: raise Exception(f"Tool '{name}' not found in available tools.") @@ -406,3 +425,12 @@ class FallbackHandler: if fallback_response and agent_type == "React": return [SystemMessage(str(msg)) for msg in fallback_response] return None + + def init_msg_list(self, full_msg_list: list[BaseMessage]) -> None: + first_two = full_msg_list[:2] + last_two = full_msg_list[-2:] + merged = first_two.copy() + for msg in last_two: + if msg not in merged: + merged.append(msg) + self.msg_list = merged diff --git a/tests/ra_aid/test_fallback_handler.py b/tests/ra_aid/test_fallback_handler.py index 1558498..3273391 100644 --- a/tests/ra_aid/test_fallback_handler.py +++ b/tests/ra_aid/test_fallback_handler.py @@ -45,7 +45,7 @@ class TestFallbackHandler(unittest.TestCase): error_obj = ToolExecutionError( "Test error", base_message="dummy_call()", tool_name="dummy_tool" ) - self.fallback_handler.handle_failure(error_obj, self.agent) + self.fallback_handler.handle_failure(error_obj, self.agent, []) self.assertEqual( self.fallback_handler.tool_failure_consecutive_failures, initial_failures + 1, @@ -295,5 +295,19 @@ class TestFallbackHandler(unittest.TestCase): self.assertIsNone(response_non) + def test_init_msg_list_non_overlapping(self): + # Test when the first two and last two messages do not overlap. + full_list = ["msg1", "msg2", "msg3", "msg4", "msg5"] + self.fallback_handler.init_msg_list(full_list) + # Expected merged list: first two ("msg1", "msg2") plus last two ("msg4", "msg5") + self.assertEqual(self.fallback_handler.msg_list, ["msg1", "msg2", "msg4", "msg5"]) + + def test_init_msg_list_with_overlap(self): + # Test when the last two messages overlap with the first two. + full_list = ["msg1", "msg2", "msg1", "msg3"] + self.fallback_handler.init_msg_list(full_list) + # Expected merged list: first two ("msg1", "msg2") plus "msg3" from the last two, since "msg1" was already present. + self.assertEqual(self.fallback_handler.msg_list, ["msg1", "msg2", "msg3"]) + if __name__ == "__main__": unittest.main()