From fb030e90493e4a5abde6301bceef2c7620250eb9 Mon Sep 17 00:00:00 2001 From: AI Christianson Date: Tue, 4 Mar 2025 08:33:22 -0500 Subject: [PATCH] disallow repeat tool calls --- ra_aid/agent_backends/ciayn_agent.py | 240 +++++++++++++----- .../test_ciayn_no_repeat_tools.py | 163 ++++++++++++ 2 files changed, 341 insertions(+), 62 deletions(-) create mode 100644 tests/ra_aid/agent_backends/test_ciayn_no_repeat_tools.py diff --git a/ra_aid/agent_backends/ciayn_agent.py b/ra_aid/agent_backends/ciayn_agent.py index 2b9138d..70d6555 100644 --- a/ra_aid/agent_backends/ciayn_agent.py +++ b/ra_aid/agent_backends/ciayn_agent.py @@ -98,7 +98,7 @@ class CiaynAgent: - Memory management with configurable limits """ - # List of tools that can be bundled (called multiple times in one response) + # List of tools that can be bundled together in a single response BUNDLEABLE_TOOLS = [ "emit_expert_context", "ask_expert", @@ -112,6 +112,23 @@ class CiaynAgent: "request_research_and_implementation", "run_shell_command", ] + + # List of tools that should not be called repeatedly with the same parameters + # This prevents the agent from getting stuck in a loop calling the same tool + # with the same arguments multiple times + NO_REPEAT_TOOLS = [ + "emit_expert_context", + "ask_expert", + "emit_key_facts", + "emit_key_snippet", + "request_implementation", + "read_file_tool", + "emit_research_notes", + "ripgrep_search", + "plan_implementation_completed", + "request_research_and_implementation", + "run_shell_command", + ] def __init__( self, @@ -156,6 +173,12 @@ class CiaynAgent: self.fallback_fixed_msg = HumanMessage( "Fallback tool handler has fixed the tool call see: for the output." ) + + # Track the most recent tool call and parameters to prevent repeats + # This is used to detect and prevent identical tool calls with the same parameters + # to avoid redundant operations and encourage the agent to try different approaches + self.last_tool_call = None + self.last_tool_params = None def _build_prompt(self, last_result: Optional[str] = None) -> str: """Build the prompt for the agent including available tools and context.""" @@ -207,19 +230,11 @@ class CiaynAgent: calls.append(call_str) else: # If any function is not bundleable, return just the original code - cpm( - f"Found multiple tool calls, but {func_name} is not bundleable.", - title="⚠ Non-bundleable Tools", - border_style="yellow" - ) + logger.debug(f"Found multiple tool calls, but {func_name} is not bundleable.") return [code] if calls: - cpm( - f"Detected {len(calls)} bundleable tool calls.", - title="✓ Bundling Tools", - border_style="green" - ) + logger.debug(f"Detected {len(calls)} bundleable tool calls.") return calls # Default case: just return the original code as a single element @@ -255,6 +270,79 @@ class CiaynAgent: if validate_function_call_pattern(call): functions_list = "\n\n".join(self.available_functions) call = self._extract_tool_call(call, functions_list) + + # Check for repeated tool calls with the same parameters + tool_name = self.extract_tool_name(call) + + if tool_name in self.NO_REPEAT_TOOLS: + # Use AST to extract parameters + try: + tree = ast.parse(call) + if (isinstance(tree.body[0], ast.Expr) and + isinstance(tree.body[0].value, ast.Call)): + + # Debug - print full AST structure + logger.debug(f"AST structure for bundled call: {ast.dump(tree.body[0].value)}") + + # Extract and normalize parameter values + param_pairs = [] + + # Handle positional arguments + if tree.body[0].value.args: + logger.debug(f"Found positional args in bundled call: {[ast.unparse(arg) for arg in tree.body[0].value.args]}") + + for i, arg in enumerate(tree.body[0].value.args): + arg_value = ast.unparse(arg) + + # Normalize string literals by removing outer quotes + if ((arg_value.startswith("'") and arg_value.endswith("'")) or + (arg_value.startswith('"') and arg_value.endswith('"'))): + arg_value = arg_value[1:-1] + + param_pairs.append((f"arg{i}", arg_value)) + + # Handle keyword arguments + for k in tree.body[0].value.keywords: + param_name = k.arg + param_value = ast.unparse(k.value) + + # Debug - print each parameter + logger.debug(f"Processing parameter: {param_name} = {param_value}") + + # Normalize string literals by removing outer quotes + if ((param_value.startswith("'") and param_value.endswith("'")) or + (param_value.startswith('"') and param_value.endswith('"'))): + param_value = param_value[1:-1] + + param_pairs.append((param_name, param_value)) + + # Debug - print extracted parameters + logger.debug(f"Extracted parameters: {param_pairs}") + + # Create a fingerprint of the call + current_call = (tool_name, str(sorted(param_pairs))) + + # Debug information to help diagnose false positives + logger.debug(f"Tool call: {tool_name}\nCurrent call fingerprint: {current_call}\nLast call fingerprint: {self.last_tool_call}") + + # If this fingerprint matches the last tool call, reject it + if current_call == self.last_tool_call: + logger.warning(f"Detected repeat call of {tool_name} with the same parameters.") + result = f"Repeat calls of {tool_name} with the same parameters are not allowed. You must try something different!" + results.append(result) + + # Generate a random ID for this result + result_id = self._generate_random_id() + result_strings.append(f"\n{result}\n") + continue + + # Update last tool call fingerprint for next comparison + self.last_tool_call = current_call + except Exception as e: + # If we can't parse parameters, just continue + # This ensures robustness when dealing with complex or malformed tool calls + logger.debug(f"Failed to parse parameters for duplicate detection: {str(e)}") + pass # Execute the call and collect the result result = eval(call.strip(), globals_dict) @@ -269,25 +357,85 @@ class CiaynAgent: # Regular single tool call case if validate_function_call_pattern(code): - cpm( - f"Tool call validation failed. Attempting to extract function call using LLM.", - title="⚠ Validation Warning", - border_style="yellow" - ) + logger.debug(f"Tool call validation failed. Attempting to extract function call using LLM.") functions_list = "\n\n".join(self.available_functions) code = self._extract_tool_call(code, functions_list) - pass + # Check for repeated tool call with the same parameters (single tool case) + tool_name = self.extract_tool_name(code) + + # If the tool is in the NO_REPEAT_TOOLS list, check for repeat calls + if tool_name in self.NO_REPEAT_TOOLS: + # Use AST to extract parameters + try: + tree = ast.parse(code) + if (isinstance(tree.body[0], ast.Expr) and + isinstance(tree.body[0].value, ast.Call)): + + # Debug - print full AST structure + logger.debug(f"AST structure for single call: {ast.dump(tree.body[0].value)}") + + # Extract and normalize parameter values + param_pairs = [] + + # Handle positional arguments + if tree.body[0].value.args: + logger.debug(f"Found positional args in single call: {[ast.unparse(arg) for arg in tree.body[0].value.args]}") + + for i, arg in enumerate(tree.body[0].value.args): + arg_value = ast.unparse(arg) + + # Normalize string literals by removing outer quotes + if ((arg_value.startswith("'") and arg_value.endswith("'")) or + (arg_value.startswith('"') and arg_value.endswith('"'))): + arg_value = arg_value[1:-1] + + param_pairs.append((f"arg{i}", arg_value)) + + # Handle keyword arguments + for k in tree.body[0].value.keywords: + param_name = k.arg + param_value = ast.unparse(k.value) + + # Debug - print each parameter + logger.debug(f"Processing parameter: {param_name} = {param_value}") + + # Normalize string literals by removing outer quotes + if ((param_value.startswith("'") and param_value.endswith("'")) or + (param_value.startswith('"') and param_value.endswith('"'))): + param_value = param_value[1:-1] + + param_pairs.append((param_name, param_value)) + + # Also check for positional arguments + if tree.body[0].value.args: + logger.debug(f"Found positional args: {[ast.unparse(arg) for arg in tree.body[0].value.args]}") + + # Create a fingerprint of the call + current_call = (tool_name, str(sorted(param_pairs))) + + # Debug information to help diagnose false positives + logger.debug(f"Tool call: {tool_name}\nCurrent call fingerprint: {current_call}\nLast call fingerprint: {self.last_tool_call}") + + # If this fingerprint matches the last tool call, reject it + if current_call == self.last_tool_call: + logger.warning(f"Detected repeat call of {tool_name} with the same parameters.") + return f"Repeat calls of {tool_name} with the same parameters are not allowed. You must try something different!" + + # Update last tool call fingerprint for next comparison + self.last_tool_call = current_call + except Exception as e: + # If we can't parse parameters, just continue with the tool execution + # This ensures robustness when dealing with complex or malformed tool calls + logger.debug(f"Failed to parse parameters for duplicate detection: {str(e)}") + pass + result = eval(code.strip(), globals_dict) return result except Exception as e: error_msg = f"Error: {str(e)} \n Could not execute code: {code}" tool_name = self.extract_tool_name(code) - cpm( - f"Tool execution failed for `{tool_name}`: {str(e)}", - title="❗ Tool Error", - border_style="red" - ) + logger.warning(f"Tool execution failed for `{tool_name}`: {str(e)}") raise ToolExecutionError( error_msg, base_message=msg, tool_name=tool_name ) from e @@ -319,11 +467,7 @@ class CiaynAgent: if not fallback_response: self.chat_history.append(err_msg) - cpm( - f"Tool fallback was attempted but did not succeed. Original error: {str(e)}", - title="❗ Fallback Failed", - border_style="red bold" - ) + logger.warning(f"Tool fallback was attempted but did not succeed. Original error: {str(e)}") return "" self.chat_history.append(self.fallback_fixed_msg) @@ -333,11 +477,7 @@ class CiaynAgent: msg += f"{e.tool_name}\n" msg += f"\n{fallback_response[1]}\n\n" - cpm( - f"Fallback successful for tool `{e.tool_name}` after {DEFAULT_MAX_TOOL_FAILURES} consecutive failures.", - title="✓ Fallback Success", - border_style="green" - ) + logger.info(f"Fallback successful for tool `{e.tool_name}` after {DEFAULT_MAX_TOOL_FAILURES} consecutive failures.") return msg @@ -416,11 +556,7 @@ class CiaynAgent: def _extract_tool_call(self, code: str, functions_list: str) -> str: """Extract a tool call from the code using a language model.""" model = get_model() - cpm( - f"Attempting to fix malformed tool call using LLM. Original code:\n```\n{code}\n```", - title="🔧 Tool Call Extraction", - border_style="blue" - ) + logger.debug(f"Attempting to fix malformed tool call using LLM. Original code:\n```\n{code}\n```") prompt = EXTRACT_TOOL_CALL_PROMPT.format( functions_list=functions_list, code=code ) @@ -430,20 +566,12 @@ class CiaynAgent: pattern = r"([\w_\-]+)\((.*?)\)" matches = re.findall(pattern, response, re.DOTALL) if len(matches) == 0: - cpm( - "Failed to extract a valid tool call from the model's response.", - title="❗ Extraction Failed", - border_style="red" - ) + logger.warning("Failed to extract a valid tool call from the model's response.") raise ToolExecutionError("Failed to extract tool call") ma = matches[0][0].strip() mb = matches[0][1].strip().replace("\n", " ") fixed_code = f"{ma}({mb})" - cpm( - f"Successfully extracted tool call: `{fixed_code}`", - title="✓ Extraction Success", - border_style="green" - ) + logger.debug(f"Successfully extracted tool call: `{fixed_code}`") return fixed_code def stream( @@ -467,11 +595,7 @@ class CiaynAgent: empty_response_count += 1 logger.warning(f"Model returned empty response (count: {empty_response_count})") - cpm( - f"The model returned an empty response (attempt {empty_response_count} of {max_empty_responses}). Requesting the model to make a valid tool call.", - title="⚠ Empty Response", - border_style="yellow bold" - ) + logger.warning(f"The model returned an empty response (attempt {empty_response_count} of {max_empty_responses}). Requesting the model to make a valid tool call.") if empty_response_count >= max_empty_responses: # If we've had too many empty responses, raise an error to break the loop @@ -480,11 +604,7 @@ class CiaynAgent: mark_agent_crashed(crash_message) logger.error(crash_message) - cpm( - "The agent has crashed after multiple failed attempts to generate a valid tool call.", - title="❗ Agent Crashed", - border_style="red bold" - ) + logger.error("The agent has crashed after multiple failed attempts to generate a valid tool call.") yield self._create_error_chunk(crash_message) return @@ -504,11 +624,7 @@ class CiaynAgent: yield {} except ToolExecutionError as e: - cpm( - f"Tool execution error: {str(e)}. Attempting fallback...", - title="↻ Fallback Attempt", - border_style="yellow" - ) + logger.warning(f"Tool execution error: {str(e)}. Attempting fallback...") fallback_response = self.fallback_handler.handle_failure( e, self, self.chat_history ) diff --git a/tests/ra_aid/agent_backends/test_ciayn_no_repeat_tools.py b/tests/ra_aid/agent_backends/test_ciayn_no_repeat_tools.py new file mode 100644 index 0000000..6cf4f04 --- /dev/null +++ b/tests/ra_aid/agent_backends/test_ciayn_no_repeat_tools.py @@ -0,0 +1,163 @@ +import pytest +from unittest.mock import MagicMock, patch +from langchain_core.messages import AIMessage + +from ra_aid.agent_backends.ciayn_agent import CiaynAgent + +class TestNoRepeatTools: + @pytest.fixture + def mock_model(self): + mock = MagicMock() + return mock + + @pytest.fixture + def mock_tool(self): + mock = MagicMock() + mock.func.__name__ = "test_tool" + mock.func.return_value = "Tool execution result" + return mock + + @pytest.fixture + def agent(self, mock_model, mock_tool): + # Create the agent with our mock model and tool + agent = CiaynAgent(mock_model, [mock_tool]) + # Add the test tool to the NO_REPEAT_TOOLS list + agent.NO_REPEAT_TOOLS.append("test_tool") + return agent + + def test_repeat_tool_call_rejection(self, agent, mock_tool): + """Test that repeat tool calls with the same parameters are rejected.""" + # First call should succeed + first_message = AIMessage(content="test_tool(param1='value1', param2='value2')") + result1 = agent._execute_tool(first_message) + assert result1 == "Tool execution result" + + # Second identical call should be rejected + second_message = AIMessage(content="test_tool(param1='value1', param2='value2')") + result2 = agent._execute_tool(second_message) + assert "Repeat calls of test_tool with the same parameters are not allowed" in result2 + + # Call with different parameters should succeed + third_message = AIMessage(content="test_tool(param1='value1', param2='different')") + result3 = agent._execute_tool(third_message) + assert result3 == "Tool execution result" + + def test_bundled_calls_repeat_rejection(self, agent, mock_tool): + """Test that repeat tool calls in bundled calls are rejected.""" + # Mock the _detect_multiple_tool_calls method to simulate bundled calls + with patch.object(agent, '_detect_multiple_tool_calls') as mock_detect: + # Set up two bundled calls where the second one is a repeat + mock_detect.return_value = [ + "test_tool(param1='value1', param2='value2')", + "test_tool(param1='value1', param2='value2')" + ] + + # Execute the bundled calls + message = AIMessage(content="test_tool(...)\ntest_tool(...)") + result = agent._execute_tool(message) + + # First call should succeed, second should be rejected + assert "Tool execution result" in result + assert "Repeat calls of test_tool with the same parameters are not allowed" in result + + def test_different_tool_not_affected(self, mock_model): + """Test that tools not in NO_REPEAT_TOOLS list can be called repeatedly.""" + # Create different mock tools for this test + mock_tool1 = MagicMock() + mock_tool1.func.__name__ = "non_repeat_tool" + mock_tool1.func.return_value = "Tool execution result" + + # Create a fresh agent with our mock model and tool + agent = CiaynAgent(mock_model, [mock_tool1]) + + # First call + first_message = AIMessage(content="non_repeat_tool(param1='value1', param2='value2')") + result1 = agent._execute_tool(first_message) + assert result1 == "Tool execution result" + + # Second identical call should also succeed because this tool is not in NO_REPEAT_TOOLS + second_message = AIMessage(content="non_repeat_tool(param1='value1', param2='value2')") + result2 = agent._execute_tool(second_message) + assert result2 == "Tool execution result" + + def test_run_shell_command_detection(self, mock_model): + """Test the shell command detection logic to ensure it's not creating false positives.""" + # Create mock tools for this test + mock_tool1 = MagicMock() + mock_tool1.func.__name__ = "run_shell_command" + mock_tool1.func.return_value = "Shell command result" + + # Create a fresh agent + agent = CiaynAgent(mock_model, [mock_tool1]) + + # First call to run_shell_command + first_message = AIMessage(content="run_shell_command(CommandLine='g++ main.cpp -o spinning_cube -lGL -lGLU -lglut', Cwd='/home/user', Blocking=True)") + result1 = agent._execute_tool(first_message) + assert result1 == "Shell command result" + + # Second call with same command should be detected as duplicate + second_message = AIMessage(content="run_shell_command(CommandLine='g++ main.cpp -o spinning_cube -lGL -lGLU -lglut', Cwd='/home/user', Blocking=True)") + result2 = agent._execute_tool(second_message) + assert "Repeat calls of run_shell_command with the same parameters are not allowed" in result2 + + # Different command should work + third_message = AIMessage(content="run_shell_command(CommandLine='./spinning_cube', Cwd='/home/user', Blocking=True)") + result3 = agent._execute_tool(third_message) + assert result3 == "Shell command result" + + # Test with same command but different parameter order (should still be detected as duplicate) + fourth_message = AIMessage(content="run_shell_command(Blocking=True, Cwd='/home/user', CommandLine='./spinning_cube')") + result4 = agent._execute_tool(fourth_message) + assert "Repeat calls of run_shell_command with the same parameters are not allowed" in result4 + + # Test with different boolean value (True vs true) - these should be treated as different + fifth_message = AIMessage(content="run_shell_command(CommandLine='ls -la', Cwd='/home/user', Blocking=True)") # different command + result5 = agent._execute_tool(fifth_message) + assert result5 == "Shell command result" + + # Test with SafeToAutoRun parameter - should be treated as different + sixth_message = AIMessage(content="run_shell_command(CommandLine='ls -la', Cwd='/home/user', Blocking=True, SafeToAutoRun=True)") + result6 = agent._execute_tool(sixth_message) + assert result6 == "Shell command result" + + def test_positional_args_detection(self, mock_model): + """Test that positional arguments are properly included in fingerprinting.""" + # Create mock tools for this test + mock_tool1 = MagicMock() + mock_tool1.func.__name__ = "test_tool" + mock_tool1.func.return_value = "Tool execution result" + + # Create a fresh agent + agent = CiaynAgent(mock_model, [mock_tool1]) + # Add the test tool to the NO_REPEAT_TOOLS list + agent.NO_REPEAT_TOOLS.append("test_tool") + + # First call with positional args + first_message = AIMessage(content="test_tool('value1', 'value2')") + result1 = agent._execute_tool(first_message) + assert result1 == "Tool execution result" + + # Second identical call should be rejected + second_message = AIMessage(content="test_tool('value1', 'value2')") + result2 = agent._execute_tool(second_message) + assert "Repeat calls of test_tool with the same parameters are not allowed" in result2 + + # Call with different positional args should succeed + third_message = AIMessage(content="test_tool('value1', 'different')") + result3 = agent._execute_tool(third_message) + assert result3 == "Tool execution result" + + # Call with same values but as keyword args should be considered different + fourth_message = AIMessage(content="test_tool(param1='value1', param2='different')") + result4 = agent._execute_tool(fourth_message) + assert result4 == "Tool execution result" + + # Mixed positional and keyword args + fifth_message = AIMessage(content="test_tool('value1', param2='different')") + result5 = agent._execute_tool(fifth_message) + assert result5 == "Tool execution result" + + # Repeat of mixed call should be rejected + sixth_message = AIMessage(content="test_tool('value1', param2='different')") + result6 = agent._execute_tool(sixth_message) + assert "Repeat calls of test_tool with the same parameters are not allowed" in result6