disallow repeat tool calls

This commit is contained in:
AI Christianson 2025-03-04 08:33:22 -05:00
parent 5dfb41b000
commit fb030e9049
2 changed files with 341 additions and 62 deletions

View File

@ -98,7 +98,7 @@ class CiaynAgent:
- Memory management with configurable limits
"""
# List of tools that can be bundled (called multiple times in one response)
# List of tools that can be bundled together in a single response
BUNDLEABLE_TOOLS = [
"emit_expert_context",
"ask_expert",
@ -112,6 +112,23 @@ class CiaynAgent:
"request_research_and_implementation",
"run_shell_command",
]
# List of tools that should not be called repeatedly with the same parameters
# This prevents the agent from getting stuck in a loop calling the same tool
# with the same arguments multiple times
NO_REPEAT_TOOLS = [
"emit_expert_context",
"ask_expert",
"emit_key_facts",
"emit_key_snippet",
"request_implementation",
"read_file_tool",
"emit_research_notes",
"ripgrep_search",
"plan_implementation_completed",
"request_research_and_implementation",
"run_shell_command",
]
def __init__(
self,
@ -156,6 +173,12 @@ class CiaynAgent:
self.fallback_fixed_msg = HumanMessage(
"Fallback tool handler has fixed the tool call see: <fallback tool call result> for the output."
)
# Track the most recent tool call and parameters to prevent repeats
# This is used to detect and prevent identical tool calls with the same parameters
# to avoid redundant operations and encourage the agent to try different approaches
self.last_tool_call = None
self.last_tool_params = None
def _build_prompt(self, last_result: Optional[str] = None) -> str:
"""Build the prompt for the agent including available tools and context."""
@ -207,19 +230,11 @@ class CiaynAgent:
calls.append(call_str)
else:
# If any function is not bundleable, return just the original code
cpm(
f"Found multiple tool calls, but {func_name} is not bundleable.",
title="⚠ Non-bundleable Tools",
border_style="yellow"
)
logger.debug(f"Found multiple tool calls, but {func_name} is not bundleable.")
return [code]
if calls:
cpm(
f"Detected {len(calls)} bundleable tool calls.",
title="✓ Bundling Tools",
border_style="green"
)
logger.debug(f"Detected {len(calls)} bundleable tool calls.")
return calls
# Default case: just return the original code as a single element
@ -255,6 +270,79 @@ class CiaynAgent:
if validate_function_call_pattern(call):
functions_list = "\n\n".join(self.available_functions)
call = self._extract_tool_call(call, functions_list)
# Check for repeated tool calls with the same parameters
tool_name = self.extract_tool_name(call)
if tool_name in self.NO_REPEAT_TOOLS:
# Use AST to extract parameters
try:
tree = ast.parse(call)
if (isinstance(tree.body[0], ast.Expr) and
isinstance(tree.body[0].value, ast.Call)):
# Debug - print full AST structure
logger.debug(f"AST structure for bundled call: {ast.dump(tree.body[0].value)}")
# Extract and normalize parameter values
param_pairs = []
# Handle positional arguments
if tree.body[0].value.args:
logger.debug(f"Found positional args in bundled call: {[ast.unparse(arg) for arg in tree.body[0].value.args]}")
for i, arg in enumerate(tree.body[0].value.args):
arg_value = ast.unparse(arg)
# Normalize string literals by removing outer quotes
if ((arg_value.startswith("'") and arg_value.endswith("'")) or
(arg_value.startswith('"') and arg_value.endswith('"'))):
arg_value = arg_value[1:-1]
param_pairs.append((f"arg{i}", arg_value))
# Handle keyword arguments
for k in tree.body[0].value.keywords:
param_name = k.arg
param_value = ast.unparse(k.value)
# Debug - print each parameter
logger.debug(f"Processing parameter: {param_name} = {param_value}")
# Normalize string literals by removing outer quotes
if ((param_value.startswith("'") and param_value.endswith("'")) or
(param_value.startswith('"') and param_value.endswith('"'))):
param_value = param_value[1:-1]
param_pairs.append((param_name, param_value))
# Debug - print extracted parameters
logger.debug(f"Extracted parameters: {param_pairs}")
# Create a fingerprint of the call
current_call = (tool_name, str(sorted(param_pairs)))
# Debug information to help diagnose false positives
logger.debug(f"Tool call: {tool_name}\nCurrent call fingerprint: {current_call}\nLast call fingerprint: {self.last_tool_call}")
# If this fingerprint matches the last tool call, reject it
if current_call == self.last_tool_call:
logger.warning(f"Detected repeat call of {tool_name} with the same parameters.")
result = f"Repeat calls of {tool_name} with the same parameters are not allowed. You must try something different!"
results.append(result)
# Generate a random ID for this result
result_id = self._generate_random_id()
result_strings.append(f"<result-{result_id}>\n{result}\n</result-{result_id}>")
continue
# Update last tool call fingerprint for next comparison
self.last_tool_call = current_call
except Exception as e:
# If we can't parse parameters, just continue
# This ensures robustness when dealing with complex or malformed tool calls
logger.debug(f"Failed to parse parameters for duplicate detection: {str(e)}")
pass
# Execute the call and collect the result
result = eval(call.strip(), globals_dict)
@ -269,25 +357,85 @@ class CiaynAgent:
# Regular single tool call case
if validate_function_call_pattern(code):
cpm(
f"Tool call validation failed. Attempting to extract function call using LLM.",
title="⚠ Validation Warning",
border_style="yellow"
)
logger.debug(f"Tool call validation failed. Attempting to extract function call using LLM.")
functions_list = "\n\n".join(self.available_functions)
code = self._extract_tool_call(code, functions_list)
pass
# Check for repeated tool call with the same parameters (single tool case)
tool_name = self.extract_tool_name(code)
# If the tool is in the NO_REPEAT_TOOLS list, check for repeat calls
if tool_name in self.NO_REPEAT_TOOLS:
# Use AST to extract parameters
try:
tree = ast.parse(code)
if (isinstance(tree.body[0], ast.Expr) and
isinstance(tree.body[0].value, ast.Call)):
# Debug - print full AST structure
logger.debug(f"AST structure for single call: {ast.dump(tree.body[0].value)}")
# Extract and normalize parameter values
param_pairs = []
# Handle positional arguments
if tree.body[0].value.args:
logger.debug(f"Found positional args in single call: {[ast.unparse(arg) for arg in tree.body[0].value.args]}")
for i, arg in enumerate(tree.body[0].value.args):
arg_value = ast.unparse(arg)
# Normalize string literals by removing outer quotes
if ((arg_value.startswith("'") and arg_value.endswith("'")) or
(arg_value.startswith('"') and arg_value.endswith('"'))):
arg_value = arg_value[1:-1]
param_pairs.append((f"arg{i}", arg_value))
# Handle keyword arguments
for k in tree.body[0].value.keywords:
param_name = k.arg
param_value = ast.unparse(k.value)
# Debug - print each parameter
logger.debug(f"Processing parameter: {param_name} = {param_value}")
# Normalize string literals by removing outer quotes
if ((param_value.startswith("'") and param_value.endswith("'")) or
(param_value.startswith('"') and param_value.endswith('"'))):
param_value = param_value[1:-1]
param_pairs.append((param_name, param_value))
# Also check for positional arguments
if tree.body[0].value.args:
logger.debug(f"Found positional args: {[ast.unparse(arg) for arg in tree.body[0].value.args]}")
# Create a fingerprint of the call
current_call = (tool_name, str(sorted(param_pairs)))
# Debug information to help diagnose false positives
logger.debug(f"Tool call: {tool_name}\nCurrent call fingerprint: {current_call}\nLast call fingerprint: {self.last_tool_call}")
# If this fingerprint matches the last tool call, reject it
if current_call == self.last_tool_call:
logger.warning(f"Detected repeat call of {tool_name} with the same parameters.")
return f"Repeat calls of {tool_name} with the same parameters are not allowed. You must try something different!"
# Update last tool call fingerprint for next comparison
self.last_tool_call = current_call
except Exception as e:
# If we can't parse parameters, just continue with the tool execution
# This ensures robustness when dealing with complex or malformed tool calls
logger.debug(f"Failed to parse parameters for duplicate detection: {str(e)}")
pass
result = eval(code.strip(), globals_dict)
return result
except Exception as e:
error_msg = f"Error: {str(e)} \n Could not execute code: {code}"
tool_name = self.extract_tool_name(code)
cpm(
f"Tool execution failed for `{tool_name}`: {str(e)}",
title="❗ Tool Error",
border_style="red"
)
logger.warning(f"Tool execution failed for `{tool_name}`: {str(e)}")
raise ToolExecutionError(
error_msg, base_message=msg, tool_name=tool_name
) from e
@ -319,11 +467,7 @@ class CiaynAgent:
if not fallback_response:
self.chat_history.append(err_msg)
cpm(
f"Tool fallback was attempted but did not succeed. Original error: {str(e)}",
title="❗ Fallback Failed",
border_style="red bold"
)
logger.warning(f"Tool fallback was attempted but did not succeed. Original error: {str(e)}")
return ""
self.chat_history.append(self.fallback_fixed_msg)
@ -333,11 +477,7 @@ class CiaynAgent:
msg += f"<fallback tool name>{e.tool_name}</fallback tool name>\n"
msg += f"<fallback tool call result>\n{fallback_response[1]}\n</fallback tool call result>\n"
cpm(
f"Fallback successful for tool `{e.tool_name}` after {DEFAULT_MAX_TOOL_FAILURES} consecutive failures.",
title="✓ Fallback Success",
border_style="green"
)
logger.info(f"Fallback successful for tool `{e.tool_name}` after {DEFAULT_MAX_TOOL_FAILURES} consecutive failures.")
return msg
@ -416,11 +556,7 @@ class CiaynAgent:
def _extract_tool_call(self, code: str, functions_list: str) -> str:
"""Extract a tool call from the code using a language model."""
model = get_model()
cpm(
f"Attempting to fix malformed tool call using LLM. Original code:\n```\n{code}\n```",
title="🔧 Tool Call Extraction",
border_style="blue"
)
logger.debug(f"Attempting to fix malformed tool call using LLM. Original code:\n```\n{code}\n```")
prompt = EXTRACT_TOOL_CALL_PROMPT.format(
functions_list=functions_list, code=code
)
@ -430,20 +566,12 @@ class CiaynAgent:
pattern = r"([\w_\-]+)\((.*?)\)"
matches = re.findall(pattern, response, re.DOTALL)
if len(matches) == 0:
cpm(
"Failed to extract a valid tool call from the model's response.",
title="❗ Extraction Failed",
border_style="red"
)
logger.warning("Failed to extract a valid tool call from the model's response.")
raise ToolExecutionError("Failed to extract tool call")
ma = matches[0][0].strip()
mb = matches[0][1].strip().replace("\n", " ")
fixed_code = f"{ma}({mb})"
cpm(
f"Successfully extracted tool call: `{fixed_code}`",
title="✓ Extraction Success",
border_style="green"
)
logger.debug(f"Successfully extracted tool call: `{fixed_code}`")
return fixed_code
def stream(
@ -467,11 +595,7 @@ class CiaynAgent:
empty_response_count += 1
logger.warning(f"Model returned empty response (count: {empty_response_count})")
cpm(
f"The model returned an empty response (attempt {empty_response_count} of {max_empty_responses}). Requesting the model to make a valid tool call.",
title="⚠ Empty Response",
border_style="yellow bold"
)
logger.warning(f"The model returned an empty response (attempt {empty_response_count} of {max_empty_responses}). Requesting the model to make a valid tool call.")
if empty_response_count >= max_empty_responses:
# If we've had too many empty responses, raise an error to break the loop
@ -480,11 +604,7 @@ class CiaynAgent:
mark_agent_crashed(crash_message)
logger.error(crash_message)
cpm(
"The agent has crashed after multiple failed attempts to generate a valid tool call.",
title="❗ Agent Crashed",
border_style="red bold"
)
logger.error("The agent has crashed after multiple failed attempts to generate a valid tool call.")
yield self._create_error_chunk(crash_message)
return
@ -504,11 +624,7 @@ class CiaynAgent:
yield {}
except ToolExecutionError as e:
cpm(
f"Tool execution error: {str(e)}. Attempting fallback...",
title="↻ Fallback Attempt",
border_style="yellow"
)
logger.warning(f"Tool execution error: {str(e)}. Attempting fallback...")
fallback_response = self.fallback_handler.handle_failure(
e, self, self.chat_history
)

View File

@ -0,0 +1,163 @@
import pytest
from unittest.mock import MagicMock, patch
from langchain_core.messages import AIMessage
from ra_aid.agent_backends.ciayn_agent import CiaynAgent
class TestNoRepeatTools:
@pytest.fixture
def mock_model(self):
mock = MagicMock()
return mock
@pytest.fixture
def mock_tool(self):
mock = MagicMock()
mock.func.__name__ = "test_tool"
mock.func.return_value = "Tool execution result"
return mock
@pytest.fixture
def agent(self, mock_model, mock_tool):
# Create the agent with our mock model and tool
agent = CiaynAgent(mock_model, [mock_tool])
# Add the test tool to the NO_REPEAT_TOOLS list
agent.NO_REPEAT_TOOLS.append("test_tool")
return agent
def test_repeat_tool_call_rejection(self, agent, mock_tool):
"""Test that repeat tool calls with the same parameters are rejected."""
# First call should succeed
first_message = AIMessage(content="test_tool(param1='value1', param2='value2')")
result1 = agent._execute_tool(first_message)
assert result1 == "Tool execution result"
# Second identical call should be rejected
second_message = AIMessage(content="test_tool(param1='value1', param2='value2')")
result2 = agent._execute_tool(second_message)
assert "Repeat calls of test_tool with the same parameters are not allowed" in result2
# Call with different parameters should succeed
third_message = AIMessage(content="test_tool(param1='value1', param2='different')")
result3 = agent._execute_tool(third_message)
assert result3 == "Tool execution result"
def test_bundled_calls_repeat_rejection(self, agent, mock_tool):
"""Test that repeat tool calls in bundled calls are rejected."""
# Mock the _detect_multiple_tool_calls method to simulate bundled calls
with patch.object(agent, '_detect_multiple_tool_calls') as mock_detect:
# Set up two bundled calls where the second one is a repeat
mock_detect.return_value = [
"test_tool(param1='value1', param2='value2')",
"test_tool(param1='value1', param2='value2')"
]
# Execute the bundled calls
message = AIMessage(content="test_tool(...)\ntest_tool(...)")
result = agent._execute_tool(message)
# First call should succeed, second should be rejected
assert "Tool execution result" in result
assert "Repeat calls of test_tool with the same parameters are not allowed" in result
def test_different_tool_not_affected(self, mock_model):
"""Test that tools not in NO_REPEAT_TOOLS list can be called repeatedly."""
# Create different mock tools for this test
mock_tool1 = MagicMock()
mock_tool1.func.__name__ = "non_repeat_tool"
mock_tool1.func.return_value = "Tool execution result"
# Create a fresh agent with our mock model and tool
agent = CiaynAgent(mock_model, [mock_tool1])
# First call
first_message = AIMessage(content="non_repeat_tool(param1='value1', param2='value2')")
result1 = agent._execute_tool(first_message)
assert result1 == "Tool execution result"
# Second identical call should also succeed because this tool is not in NO_REPEAT_TOOLS
second_message = AIMessage(content="non_repeat_tool(param1='value1', param2='value2')")
result2 = agent._execute_tool(second_message)
assert result2 == "Tool execution result"
def test_run_shell_command_detection(self, mock_model):
"""Test the shell command detection logic to ensure it's not creating false positives."""
# Create mock tools for this test
mock_tool1 = MagicMock()
mock_tool1.func.__name__ = "run_shell_command"
mock_tool1.func.return_value = "Shell command result"
# Create a fresh agent
agent = CiaynAgent(mock_model, [mock_tool1])
# First call to run_shell_command
first_message = AIMessage(content="run_shell_command(CommandLine='g++ main.cpp -o spinning_cube -lGL -lGLU -lglut', Cwd='/home/user', Blocking=True)")
result1 = agent._execute_tool(first_message)
assert result1 == "Shell command result"
# Second call with same command should be detected as duplicate
second_message = AIMessage(content="run_shell_command(CommandLine='g++ main.cpp -o spinning_cube -lGL -lGLU -lglut', Cwd='/home/user', Blocking=True)")
result2 = agent._execute_tool(second_message)
assert "Repeat calls of run_shell_command with the same parameters are not allowed" in result2
# Different command should work
third_message = AIMessage(content="run_shell_command(CommandLine='./spinning_cube', Cwd='/home/user', Blocking=True)")
result3 = agent._execute_tool(third_message)
assert result3 == "Shell command result"
# Test with same command but different parameter order (should still be detected as duplicate)
fourth_message = AIMessage(content="run_shell_command(Blocking=True, Cwd='/home/user', CommandLine='./spinning_cube')")
result4 = agent._execute_tool(fourth_message)
assert "Repeat calls of run_shell_command with the same parameters are not allowed" in result4
# Test with different boolean value (True vs true) - these should be treated as different
fifth_message = AIMessage(content="run_shell_command(CommandLine='ls -la', Cwd='/home/user', Blocking=True)") # different command
result5 = agent._execute_tool(fifth_message)
assert result5 == "Shell command result"
# Test with SafeToAutoRun parameter - should be treated as different
sixth_message = AIMessage(content="run_shell_command(CommandLine='ls -la', Cwd='/home/user', Blocking=True, SafeToAutoRun=True)")
result6 = agent._execute_tool(sixth_message)
assert result6 == "Shell command result"
def test_positional_args_detection(self, mock_model):
"""Test that positional arguments are properly included in fingerprinting."""
# Create mock tools for this test
mock_tool1 = MagicMock()
mock_tool1.func.__name__ = "test_tool"
mock_tool1.func.return_value = "Tool execution result"
# Create a fresh agent
agent = CiaynAgent(mock_model, [mock_tool1])
# Add the test tool to the NO_REPEAT_TOOLS list
agent.NO_REPEAT_TOOLS.append("test_tool")
# First call with positional args
first_message = AIMessage(content="test_tool('value1', 'value2')")
result1 = agent._execute_tool(first_message)
assert result1 == "Tool execution result"
# Second identical call should be rejected
second_message = AIMessage(content="test_tool('value1', 'value2')")
result2 = agent._execute_tool(second_message)
assert "Repeat calls of test_tool with the same parameters are not allowed" in result2
# Call with different positional args should succeed
third_message = AIMessage(content="test_tool('value1', 'different')")
result3 = agent._execute_tool(third_message)
assert result3 == "Tool execution result"
# Call with same values but as keyword args should be considered different
fourth_message = AIMessage(content="test_tool(param1='value1', param2='different')")
result4 = agent._execute_tool(fourth_message)
assert result4 == "Tool execution result"
# Mixed positional and keyword args
fifth_message = AIMessage(content="test_tool('value1', param2='different')")
result5 = agent._execute_tool(fifth_message)
assert result5 == "Tool execution result"
# Repeat of mixed call should be rejected
sixth_message = AIMessage(content="test_tool('value1', param2='different')")
result6 = agent._execute_tool(sixth_message)
assert "Repeat calls of test_tool with the same parameters are not allowed" in result6