disallow repeat tool calls
This commit is contained in:
parent
5dfb41b000
commit
fb030e9049
|
|
@ -98,7 +98,7 @@ class CiaynAgent:
|
||||||
- Memory management with configurable limits
|
- Memory management with configurable limits
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# List of tools that can be bundled (called multiple times in one response)
|
# List of tools that can be bundled together in a single response
|
||||||
BUNDLEABLE_TOOLS = [
|
BUNDLEABLE_TOOLS = [
|
||||||
"emit_expert_context",
|
"emit_expert_context",
|
||||||
"ask_expert",
|
"ask_expert",
|
||||||
|
|
@ -113,6 +113,23 @@ class CiaynAgent:
|
||||||
"run_shell_command",
|
"run_shell_command",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# List of tools that should not be called repeatedly with the same parameters
|
||||||
|
# This prevents the agent from getting stuck in a loop calling the same tool
|
||||||
|
# with the same arguments multiple times
|
||||||
|
NO_REPEAT_TOOLS = [
|
||||||
|
"emit_expert_context",
|
||||||
|
"ask_expert",
|
||||||
|
"emit_key_facts",
|
||||||
|
"emit_key_snippet",
|
||||||
|
"request_implementation",
|
||||||
|
"read_file_tool",
|
||||||
|
"emit_research_notes",
|
||||||
|
"ripgrep_search",
|
||||||
|
"plan_implementation_completed",
|
||||||
|
"request_research_and_implementation",
|
||||||
|
"run_shell_command",
|
||||||
|
]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model: BaseChatModel,
|
model: BaseChatModel,
|
||||||
|
|
@ -157,6 +174,12 @@ class CiaynAgent:
|
||||||
"Fallback tool handler has fixed the tool call see: <fallback tool call result> for the output."
|
"Fallback tool handler has fixed the tool call see: <fallback tool call result> for the output."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Track the most recent tool call and parameters to prevent repeats
|
||||||
|
# This is used to detect and prevent identical tool calls with the same parameters
|
||||||
|
# to avoid redundant operations and encourage the agent to try different approaches
|
||||||
|
self.last_tool_call = None
|
||||||
|
self.last_tool_params = None
|
||||||
|
|
||||||
def _build_prompt(self, last_result: Optional[str] = None) -> str:
|
def _build_prompt(self, last_result: Optional[str] = None) -> str:
|
||||||
"""Build the prompt for the agent including available tools and context."""
|
"""Build the prompt for the agent including available tools and context."""
|
||||||
# Add last result section if provided
|
# Add last result section if provided
|
||||||
|
|
@ -207,19 +230,11 @@ class CiaynAgent:
|
||||||
calls.append(call_str)
|
calls.append(call_str)
|
||||||
else:
|
else:
|
||||||
# If any function is not bundleable, return just the original code
|
# If any function is not bundleable, return just the original code
|
||||||
cpm(
|
logger.debug(f"Found multiple tool calls, but {func_name} is not bundleable.")
|
||||||
f"Found multiple tool calls, but {func_name} is not bundleable.",
|
|
||||||
title="⚠ Non-bundleable Tools",
|
|
||||||
border_style="yellow"
|
|
||||||
)
|
|
||||||
return [code]
|
return [code]
|
||||||
|
|
||||||
if calls:
|
if calls:
|
||||||
cpm(
|
logger.debug(f"Detected {len(calls)} bundleable tool calls.")
|
||||||
f"Detected {len(calls)} bundleable tool calls.",
|
|
||||||
title="✓ Bundling Tools",
|
|
||||||
border_style="green"
|
|
||||||
)
|
|
||||||
return calls
|
return calls
|
||||||
|
|
||||||
# Default case: just return the original code as a single element
|
# Default case: just return the original code as a single element
|
||||||
|
|
@ -256,6 +271,79 @@ class CiaynAgent:
|
||||||
functions_list = "\n\n".join(self.available_functions)
|
functions_list = "\n\n".join(self.available_functions)
|
||||||
call = self._extract_tool_call(call, functions_list)
|
call = self._extract_tool_call(call, functions_list)
|
||||||
|
|
||||||
|
# Check for repeated tool calls with the same parameters
|
||||||
|
tool_name = self.extract_tool_name(call)
|
||||||
|
|
||||||
|
if tool_name in self.NO_REPEAT_TOOLS:
|
||||||
|
# Use AST to extract parameters
|
||||||
|
try:
|
||||||
|
tree = ast.parse(call)
|
||||||
|
if (isinstance(tree.body[0], ast.Expr) and
|
||||||
|
isinstance(tree.body[0].value, ast.Call)):
|
||||||
|
|
||||||
|
# Debug - print full AST structure
|
||||||
|
logger.debug(f"AST structure for bundled call: {ast.dump(tree.body[0].value)}")
|
||||||
|
|
||||||
|
# Extract and normalize parameter values
|
||||||
|
param_pairs = []
|
||||||
|
|
||||||
|
# Handle positional arguments
|
||||||
|
if tree.body[0].value.args:
|
||||||
|
logger.debug(f"Found positional args in bundled call: {[ast.unparse(arg) for arg in tree.body[0].value.args]}")
|
||||||
|
|
||||||
|
for i, arg in enumerate(tree.body[0].value.args):
|
||||||
|
arg_value = ast.unparse(arg)
|
||||||
|
|
||||||
|
# Normalize string literals by removing outer quotes
|
||||||
|
if ((arg_value.startswith("'") and arg_value.endswith("'")) or
|
||||||
|
(arg_value.startswith('"') and arg_value.endswith('"'))):
|
||||||
|
arg_value = arg_value[1:-1]
|
||||||
|
|
||||||
|
param_pairs.append((f"arg{i}", arg_value))
|
||||||
|
|
||||||
|
# Handle keyword arguments
|
||||||
|
for k in tree.body[0].value.keywords:
|
||||||
|
param_name = k.arg
|
||||||
|
param_value = ast.unparse(k.value)
|
||||||
|
|
||||||
|
# Debug - print each parameter
|
||||||
|
logger.debug(f"Processing parameter: {param_name} = {param_value}")
|
||||||
|
|
||||||
|
# Normalize string literals by removing outer quotes
|
||||||
|
if ((param_value.startswith("'") and param_value.endswith("'")) or
|
||||||
|
(param_value.startswith('"') and param_value.endswith('"'))):
|
||||||
|
param_value = param_value[1:-1]
|
||||||
|
|
||||||
|
param_pairs.append((param_name, param_value))
|
||||||
|
|
||||||
|
# Debug - print extracted parameters
|
||||||
|
logger.debug(f"Extracted parameters: {param_pairs}")
|
||||||
|
|
||||||
|
# Create a fingerprint of the call
|
||||||
|
current_call = (tool_name, str(sorted(param_pairs)))
|
||||||
|
|
||||||
|
# Debug information to help diagnose false positives
|
||||||
|
logger.debug(f"Tool call: {tool_name}\nCurrent call fingerprint: {current_call}\nLast call fingerprint: {self.last_tool_call}")
|
||||||
|
|
||||||
|
# If this fingerprint matches the last tool call, reject it
|
||||||
|
if current_call == self.last_tool_call:
|
||||||
|
logger.warning(f"Detected repeat call of {tool_name} with the same parameters.")
|
||||||
|
result = f"Repeat calls of {tool_name} with the same parameters are not allowed. You must try something different!"
|
||||||
|
results.append(result)
|
||||||
|
|
||||||
|
# Generate a random ID for this result
|
||||||
|
result_id = self._generate_random_id()
|
||||||
|
result_strings.append(f"<result-{result_id}>\n{result}\n</result-{result_id}>")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Update last tool call fingerprint for next comparison
|
||||||
|
self.last_tool_call = current_call
|
||||||
|
except Exception as e:
|
||||||
|
# If we can't parse parameters, just continue
|
||||||
|
# This ensures robustness when dealing with complex or malformed tool calls
|
||||||
|
logger.debug(f"Failed to parse parameters for duplicate detection: {str(e)}")
|
||||||
|
pass
|
||||||
|
|
||||||
# Execute the call and collect the result
|
# Execute the call and collect the result
|
||||||
result = eval(call.strip(), globals_dict)
|
result = eval(call.strip(), globals_dict)
|
||||||
results.append(result)
|
results.append(result)
|
||||||
|
|
@ -269,25 +357,85 @@ class CiaynAgent:
|
||||||
|
|
||||||
# Regular single tool call case
|
# Regular single tool call case
|
||||||
if validate_function_call_pattern(code):
|
if validate_function_call_pattern(code):
|
||||||
cpm(
|
logger.debug(f"Tool call validation failed. Attempting to extract function call using LLM.")
|
||||||
f"Tool call validation failed. Attempting to extract function call using LLM.",
|
|
||||||
title="⚠ Validation Warning",
|
|
||||||
border_style="yellow"
|
|
||||||
)
|
|
||||||
functions_list = "\n\n".join(self.available_functions)
|
functions_list = "\n\n".join(self.available_functions)
|
||||||
code = self._extract_tool_call(code, functions_list)
|
code = self._extract_tool_call(code, functions_list)
|
||||||
pass
|
|
||||||
|
# Check for repeated tool call with the same parameters (single tool case)
|
||||||
|
tool_name = self.extract_tool_name(code)
|
||||||
|
|
||||||
|
# If the tool is in the NO_REPEAT_TOOLS list, check for repeat calls
|
||||||
|
if tool_name in self.NO_REPEAT_TOOLS:
|
||||||
|
# Use AST to extract parameters
|
||||||
|
try:
|
||||||
|
tree = ast.parse(code)
|
||||||
|
if (isinstance(tree.body[0], ast.Expr) and
|
||||||
|
isinstance(tree.body[0].value, ast.Call)):
|
||||||
|
|
||||||
|
# Debug - print full AST structure
|
||||||
|
logger.debug(f"AST structure for single call: {ast.dump(tree.body[0].value)}")
|
||||||
|
|
||||||
|
# Extract and normalize parameter values
|
||||||
|
param_pairs = []
|
||||||
|
|
||||||
|
# Handle positional arguments
|
||||||
|
if tree.body[0].value.args:
|
||||||
|
logger.debug(f"Found positional args in single call: {[ast.unparse(arg) for arg in tree.body[0].value.args]}")
|
||||||
|
|
||||||
|
for i, arg in enumerate(tree.body[0].value.args):
|
||||||
|
arg_value = ast.unparse(arg)
|
||||||
|
|
||||||
|
# Normalize string literals by removing outer quotes
|
||||||
|
if ((arg_value.startswith("'") and arg_value.endswith("'")) or
|
||||||
|
(arg_value.startswith('"') and arg_value.endswith('"'))):
|
||||||
|
arg_value = arg_value[1:-1]
|
||||||
|
|
||||||
|
param_pairs.append((f"arg{i}", arg_value))
|
||||||
|
|
||||||
|
# Handle keyword arguments
|
||||||
|
for k in tree.body[0].value.keywords:
|
||||||
|
param_name = k.arg
|
||||||
|
param_value = ast.unparse(k.value)
|
||||||
|
|
||||||
|
# Debug - print each parameter
|
||||||
|
logger.debug(f"Processing parameter: {param_name} = {param_value}")
|
||||||
|
|
||||||
|
# Normalize string literals by removing outer quotes
|
||||||
|
if ((param_value.startswith("'") and param_value.endswith("'")) or
|
||||||
|
(param_value.startswith('"') and param_value.endswith('"'))):
|
||||||
|
param_value = param_value[1:-1]
|
||||||
|
|
||||||
|
param_pairs.append((param_name, param_value))
|
||||||
|
|
||||||
|
# Also check for positional arguments
|
||||||
|
if tree.body[0].value.args:
|
||||||
|
logger.debug(f"Found positional args: {[ast.unparse(arg) for arg in tree.body[0].value.args]}")
|
||||||
|
|
||||||
|
# Create a fingerprint of the call
|
||||||
|
current_call = (tool_name, str(sorted(param_pairs)))
|
||||||
|
|
||||||
|
# Debug information to help diagnose false positives
|
||||||
|
logger.debug(f"Tool call: {tool_name}\nCurrent call fingerprint: {current_call}\nLast call fingerprint: {self.last_tool_call}")
|
||||||
|
|
||||||
|
# If this fingerprint matches the last tool call, reject it
|
||||||
|
if current_call == self.last_tool_call:
|
||||||
|
logger.warning(f"Detected repeat call of {tool_name} with the same parameters.")
|
||||||
|
return f"Repeat calls of {tool_name} with the same parameters are not allowed. You must try something different!"
|
||||||
|
|
||||||
|
# Update last tool call fingerprint for next comparison
|
||||||
|
self.last_tool_call = current_call
|
||||||
|
except Exception as e:
|
||||||
|
# If we can't parse parameters, just continue with the tool execution
|
||||||
|
# This ensures robustness when dealing with complex or malformed tool calls
|
||||||
|
logger.debug(f"Failed to parse parameters for duplicate detection: {str(e)}")
|
||||||
|
pass
|
||||||
|
|
||||||
result = eval(code.strip(), globals_dict)
|
result = eval(code.strip(), globals_dict)
|
||||||
return result
|
return result
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_msg = f"Error: {str(e)} \n Could not execute code: {code}"
|
error_msg = f"Error: {str(e)} \n Could not execute code: {code}"
|
||||||
tool_name = self.extract_tool_name(code)
|
tool_name = self.extract_tool_name(code)
|
||||||
cpm(
|
logger.warning(f"Tool execution failed for `{tool_name}`: {str(e)}")
|
||||||
f"Tool execution failed for `{tool_name}`: {str(e)}",
|
|
||||||
title="❗ Tool Error",
|
|
||||||
border_style="red"
|
|
||||||
)
|
|
||||||
raise ToolExecutionError(
|
raise ToolExecutionError(
|
||||||
error_msg, base_message=msg, tool_name=tool_name
|
error_msg, base_message=msg, tool_name=tool_name
|
||||||
) from e
|
) from e
|
||||||
|
|
@ -319,11 +467,7 @@ class CiaynAgent:
|
||||||
|
|
||||||
if not fallback_response:
|
if not fallback_response:
|
||||||
self.chat_history.append(err_msg)
|
self.chat_history.append(err_msg)
|
||||||
cpm(
|
logger.warning(f"Tool fallback was attempted but did not succeed. Original error: {str(e)}")
|
||||||
f"Tool fallback was attempted but did not succeed. Original error: {str(e)}",
|
|
||||||
title="❗ Fallback Failed",
|
|
||||||
border_style="red bold"
|
|
||||||
)
|
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
self.chat_history.append(self.fallback_fixed_msg)
|
self.chat_history.append(self.fallback_fixed_msg)
|
||||||
|
|
@ -333,11 +477,7 @@ class CiaynAgent:
|
||||||
msg += f"<fallback tool name>{e.tool_name}</fallback tool name>\n"
|
msg += f"<fallback tool name>{e.tool_name}</fallback tool name>\n"
|
||||||
msg += f"<fallback tool call result>\n{fallback_response[1]}\n</fallback tool call result>\n"
|
msg += f"<fallback tool call result>\n{fallback_response[1]}\n</fallback tool call result>\n"
|
||||||
|
|
||||||
cpm(
|
logger.info(f"Fallback successful for tool `{e.tool_name}` after {DEFAULT_MAX_TOOL_FAILURES} consecutive failures.")
|
||||||
f"Fallback successful for tool `{e.tool_name}` after {DEFAULT_MAX_TOOL_FAILURES} consecutive failures.",
|
|
||||||
title="✓ Fallback Success",
|
|
||||||
border_style="green"
|
|
||||||
)
|
|
||||||
|
|
||||||
return msg
|
return msg
|
||||||
|
|
||||||
|
|
@ -416,11 +556,7 @@ class CiaynAgent:
|
||||||
def _extract_tool_call(self, code: str, functions_list: str) -> str:
|
def _extract_tool_call(self, code: str, functions_list: str) -> str:
|
||||||
"""Extract a tool call from the code using a language model."""
|
"""Extract a tool call from the code using a language model."""
|
||||||
model = get_model()
|
model = get_model()
|
||||||
cpm(
|
logger.debug(f"Attempting to fix malformed tool call using LLM. Original code:\n```\n{code}\n```")
|
||||||
f"Attempting to fix malformed tool call using LLM. Original code:\n```\n{code}\n```",
|
|
||||||
title="🔧 Tool Call Extraction",
|
|
||||||
border_style="blue"
|
|
||||||
)
|
|
||||||
prompt = EXTRACT_TOOL_CALL_PROMPT.format(
|
prompt = EXTRACT_TOOL_CALL_PROMPT.format(
|
||||||
functions_list=functions_list, code=code
|
functions_list=functions_list, code=code
|
||||||
)
|
)
|
||||||
|
|
@ -430,20 +566,12 @@ class CiaynAgent:
|
||||||
pattern = r"([\w_\-]+)\((.*?)\)"
|
pattern = r"([\w_\-]+)\((.*?)\)"
|
||||||
matches = re.findall(pattern, response, re.DOTALL)
|
matches = re.findall(pattern, response, re.DOTALL)
|
||||||
if len(matches) == 0:
|
if len(matches) == 0:
|
||||||
cpm(
|
logger.warning("Failed to extract a valid tool call from the model's response.")
|
||||||
"Failed to extract a valid tool call from the model's response.",
|
|
||||||
title="❗ Extraction Failed",
|
|
||||||
border_style="red"
|
|
||||||
)
|
|
||||||
raise ToolExecutionError("Failed to extract tool call")
|
raise ToolExecutionError("Failed to extract tool call")
|
||||||
ma = matches[0][0].strip()
|
ma = matches[0][0].strip()
|
||||||
mb = matches[0][1].strip().replace("\n", " ")
|
mb = matches[0][1].strip().replace("\n", " ")
|
||||||
fixed_code = f"{ma}({mb})"
|
fixed_code = f"{ma}({mb})"
|
||||||
cpm(
|
logger.debug(f"Successfully extracted tool call: `{fixed_code}`")
|
||||||
f"Successfully extracted tool call: `{fixed_code}`",
|
|
||||||
title="✓ Extraction Success",
|
|
||||||
border_style="green"
|
|
||||||
)
|
|
||||||
return fixed_code
|
return fixed_code
|
||||||
|
|
||||||
def stream(
|
def stream(
|
||||||
|
|
@ -467,11 +595,7 @@ class CiaynAgent:
|
||||||
empty_response_count += 1
|
empty_response_count += 1
|
||||||
logger.warning(f"Model returned empty response (count: {empty_response_count})")
|
logger.warning(f"Model returned empty response (count: {empty_response_count})")
|
||||||
|
|
||||||
cpm(
|
logger.warning(f"The model returned an empty response (attempt {empty_response_count} of {max_empty_responses}). Requesting the model to make a valid tool call.")
|
||||||
f"The model returned an empty response (attempt {empty_response_count} of {max_empty_responses}). Requesting the model to make a valid tool call.",
|
|
||||||
title="⚠ Empty Response",
|
|
||||||
border_style="yellow bold"
|
|
||||||
)
|
|
||||||
|
|
||||||
if empty_response_count >= max_empty_responses:
|
if empty_response_count >= max_empty_responses:
|
||||||
# If we've had too many empty responses, raise an error to break the loop
|
# If we've had too many empty responses, raise an error to break the loop
|
||||||
|
|
@ -480,11 +604,7 @@ class CiaynAgent:
|
||||||
mark_agent_crashed(crash_message)
|
mark_agent_crashed(crash_message)
|
||||||
logger.error(crash_message)
|
logger.error(crash_message)
|
||||||
|
|
||||||
cpm(
|
logger.error("The agent has crashed after multiple failed attempts to generate a valid tool call.")
|
||||||
"The agent has crashed after multiple failed attempts to generate a valid tool call.",
|
|
||||||
title="❗ Agent Crashed",
|
|
||||||
border_style="red bold"
|
|
||||||
)
|
|
||||||
|
|
||||||
yield self._create_error_chunk(crash_message)
|
yield self._create_error_chunk(crash_message)
|
||||||
return
|
return
|
||||||
|
|
@ -504,11 +624,7 @@ class CiaynAgent:
|
||||||
yield {}
|
yield {}
|
||||||
|
|
||||||
except ToolExecutionError as e:
|
except ToolExecutionError as e:
|
||||||
cpm(
|
logger.warning(f"Tool execution error: {str(e)}. Attempting fallback...")
|
||||||
f"Tool execution error: {str(e)}. Attempting fallback...",
|
|
||||||
title="↻ Fallback Attempt",
|
|
||||||
border_style="yellow"
|
|
||||||
)
|
|
||||||
fallback_response = self.fallback_handler.handle_failure(
|
fallback_response = self.fallback_handler.handle_failure(
|
||||||
e, self, self.chat_history
|
e, self, self.chat_history
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,163 @@
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
from langchain_core.messages import AIMessage
|
||||||
|
|
||||||
|
from ra_aid.agent_backends.ciayn_agent import CiaynAgent
|
||||||
|
|
||||||
|
class TestNoRepeatTools:
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_model(self):
|
||||||
|
mock = MagicMock()
|
||||||
|
return mock
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_tool(self):
|
||||||
|
mock = MagicMock()
|
||||||
|
mock.func.__name__ = "test_tool"
|
||||||
|
mock.func.return_value = "Tool execution result"
|
||||||
|
return mock
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def agent(self, mock_model, mock_tool):
|
||||||
|
# Create the agent with our mock model and tool
|
||||||
|
agent = CiaynAgent(mock_model, [mock_tool])
|
||||||
|
# Add the test tool to the NO_REPEAT_TOOLS list
|
||||||
|
agent.NO_REPEAT_TOOLS.append("test_tool")
|
||||||
|
return agent
|
||||||
|
|
||||||
|
def test_repeat_tool_call_rejection(self, agent, mock_tool):
|
||||||
|
"""Test that repeat tool calls with the same parameters are rejected."""
|
||||||
|
# First call should succeed
|
||||||
|
first_message = AIMessage(content="test_tool(param1='value1', param2='value2')")
|
||||||
|
result1 = agent._execute_tool(first_message)
|
||||||
|
assert result1 == "Tool execution result"
|
||||||
|
|
||||||
|
# Second identical call should be rejected
|
||||||
|
second_message = AIMessage(content="test_tool(param1='value1', param2='value2')")
|
||||||
|
result2 = agent._execute_tool(second_message)
|
||||||
|
assert "Repeat calls of test_tool with the same parameters are not allowed" in result2
|
||||||
|
|
||||||
|
# Call with different parameters should succeed
|
||||||
|
third_message = AIMessage(content="test_tool(param1='value1', param2='different')")
|
||||||
|
result3 = agent._execute_tool(third_message)
|
||||||
|
assert result3 == "Tool execution result"
|
||||||
|
|
||||||
|
def test_bundled_calls_repeat_rejection(self, agent, mock_tool):
|
||||||
|
"""Test that repeat tool calls in bundled calls are rejected."""
|
||||||
|
# Mock the _detect_multiple_tool_calls method to simulate bundled calls
|
||||||
|
with patch.object(agent, '_detect_multiple_tool_calls') as mock_detect:
|
||||||
|
# Set up two bundled calls where the second one is a repeat
|
||||||
|
mock_detect.return_value = [
|
||||||
|
"test_tool(param1='value1', param2='value2')",
|
||||||
|
"test_tool(param1='value1', param2='value2')"
|
||||||
|
]
|
||||||
|
|
||||||
|
# Execute the bundled calls
|
||||||
|
message = AIMessage(content="test_tool(...)\ntest_tool(...)")
|
||||||
|
result = agent._execute_tool(message)
|
||||||
|
|
||||||
|
# First call should succeed, second should be rejected
|
||||||
|
assert "Tool execution result" in result
|
||||||
|
assert "Repeat calls of test_tool with the same parameters are not allowed" in result
|
||||||
|
|
||||||
|
def test_different_tool_not_affected(self, mock_model):
|
||||||
|
"""Test that tools not in NO_REPEAT_TOOLS list can be called repeatedly."""
|
||||||
|
# Create different mock tools for this test
|
||||||
|
mock_tool1 = MagicMock()
|
||||||
|
mock_tool1.func.__name__ = "non_repeat_tool"
|
||||||
|
mock_tool1.func.return_value = "Tool execution result"
|
||||||
|
|
||||||
|
# Create a fresh agent with our mock model and tool
|
||||||
|
agent = CiaynAgent(mock_model, [mock_tool1])
|
||||||
|
|
||||||
|
# First call
|
||||||
|
first_message = AIMessage(content="non_repeat_tool(param1='value1', param2='value2')")
|
||||||
|
result1 = agent._execute_tool(first_message)
|
||||||
|
assert result1 == "Tool execution result"
|
||||||
|
|
||||||
|
# Second identical call should also succeed because this tool is not in NO_REPEAT_TOOLS
|
||||||
|
second_message = AIMessage(content="non_repeat_tool(param1='value1', param2='value2')")
|
||||||
|
result2 = agent._execute_tool(second_message)
|
||||||
|
assert result2 == "Tool execution result"
|
||||||
|
|
||||||
|
def test_run_shell_command_detection(self, mock_model):
|
||||||
|
"""Test the shell command detection logic to ensure it's not creating false positives."""
|
||||||
|
# Create mock tools for this test
|
||||||
|
mock_tool1 = MagicMock()
|
||||||
|
mock_tool1.func.__name__ = "run_shell_command"
|
||||||
|
mock_tool1.func.return_value = "Shell command result"
|
||||||
|
|
||||||
|
# Create a fresh agent
|
||||||
|
agent = CiaynAgent(mock_model, [mock_tool1])
|
||||||
|
|
||||||
|
# First call to run_shell_command
|
||||||
|
first_message = AIMessage(content="run_shell_command(CommandLine='g++ main.cpp -o spinning_cube -lGL -lGLU -lglut', Cwd='/home/user', Blocking=True)")
|
||||||
|
result1 = agent._execute_tool(first_message)
|
||||||
|
assert result1 == "Shell command result"
|
||||||
|
|
||||||
|
# Second call with same command should be detected as duplicate
|
||||||
|
second_message = AIMessage(content="run_shell_command(CommandLine='g++ main.cpp -o spinning_cube -lGL -lGLU -lglut', Cwd='/home/user', Blocking=True)")
|
||||||
|
result2 = agent._execute_tool(second_message)
|
||||||
|
assert "Repeat calls of run_shell_command with the same parameters are not allowed" in result2
|
||||||
|
|
||||||
|
# Different command should work
|
||||||
|
third_message = AIMessage(content="run_shell_command(CommandLine='./spinning_cube', Cwd='/home/user', Blocking=True)")
|
||||||
|
result3 = agent._execute_tool(third_message)
|
||||||
|
assert result3 == "Shell command result"
|
||||||
|
|
||||||
|
# Test with same command but different parameter order (should still be detected as duplicate)
|
||||||
|
fourth_message = AIMessage(content="run_shell_command(Blocking=True, Cwd='/home/user', CommandLine='./spinning_cube')")
|
||||||
|
result4 = agent._execute_tool(fourth_message)
|
||||||
|
assert "Repeat calls of run_shell_command with the same parameters are not allowed" in result4
|
||||||
|
|
||||||
|
# Test with different boolean value (True vs true) - these should be treated as different
|
||||||
|
fifth_message = AIMessage(content="run_shell_command(CommandLine='ls -la', Cwd='/home/user', Blocking=True)") # different command
|
||||||
|
result5 = agent._execute_tool(fifth_message)
|
||||||
|
assert result5 == "Shell command result"
|
||||||
|
|
||||||
|
# Test with SafeToAutoRun parameter - should be treated as different
|
||||||
|
sixth_message = AIMessage(content="run_shell_command(CommandLine='ls -la', Cwd='/home/user', Blocking=True, SafeToAutoRun=True)")
|
||||||
|
result6 = agent._execute_tool(sixth_message)
|
||||||
|
assert result6 == "Shell command result"
|
||||||
|
|
||||||
|
def test_positional_args_detection(self, mock_model):
|
||||||
|
"""Test that positional arguments are properly included in fingerprinting."""
|
||||||
|
# Create mock tools for this test
|
||||||
|
mock_tool1 = MagicMock()
|
||||||
|
mock_tool1.func.__name__ = "test_tool"
|
||||||
|
mock_tool1.func.return_value = "Tool execution result"
|
||||||
|
|
||||||
|
# Create a fresh agent
|
||||||
|
agent = CiaynAgent(mock_model, [mock_tool1])
|
||||||
|
# Add the test tool to the NO_REPEAT_TOOLS list
|
||||||
|
agent.NO_REPEAT_TOOLS.append("test_tool")
|
||||||
|
|
||||||
|
# First call with positional args
|
||||||
|
first_message = AIMessage(content="test_tool('value1', 'value2')")
|
||||||
|
result1 = agent._execute_tool(first_message)
|
||||||
|
assert result1 == "Tool execution result"
|
||||||
|
|
||||||
|
# Second identical call should be rejected
|
||||||
|
second_message = AIMessage(content="test_tool('value1', 'value2')")
|
||||||
|
result2 = agent._execute_tool(second_message)
|
||||||
|
assert "Repeat calls of test_tool with the same parameters are not allowed" in result2
|
||||||
|
|
||||||
|
# Call with different positional args should succeed
|
||||||
|
third_message = AIMessage(content="test_tool('value1', 'different')")
|
||||||
|
result3 = agent._execute_tool(third_message)
|
||||||
|
assert result3 == "Tool execution result"
|
||||||
|
|
||||||
|
# Call with same values but as keyword args should be considered different
|
||||||
|
fourth_message = AIMessage(content="test_tool(param1='value1', param2='different')")
|
||||||
|
result4 = agent._execute_tool(fourth_message)
|
||||||
|
assert result4 == "Tool execution result"
|
||||||
|
|
||||||
|
# Mixed positional and keyword args
|
||||||
|
fifth_message = AIMessage(content="test_tool('value1', param2='different')")
|
||||||
|
result5 = agent._execute_tool(fifth_message)
|
||||||
|
assert result5 == "Tool execution result"
|
||||||
|
|
||||||
|
# Repeat of mixed call should be rejected
|
||||||
|
sixth_message = AIMessage(content="test_tool('value1', param2='different')")
|
||||||
|
result6 = agent._execute_tool(sixth_message)
|
||||||
|
assert "Repeat calls of test_tool with the same parameters are not allowed" in result6
|
||||||
Loading…
Reference in New Issue