From 13729f16ce1294c7c0eeb4ee1f63d28e02e52961 Mon Sep 17 00:00:00 2001 From: AI Christianson Date: Tue, 4 Mar 2025 01:32:59 -0500 Subject: [PATCH] AST-based parsing and validation of tool calls --- ra_aid/agent_backends/ciayn_agent.py | 167 +++++++++++++++----- ra_aid/prompts/ciayn_prompts.py | 25 ++- ra_aid/tools/read_file.py | 2 +- tests/agent_backends/test_bundled_tools.py | 174 +++++++++++++++++++++ tests/ra_aid/test_ciayn_agent.py | 7 +- 5 files changed, 329 insertions(+), 46 deletions(-) create mode 100644 tests/agent_backends/test_bundled_tools.py diff --git a/ra_aid/agent_backends/ciayn_agent.py b/ra_aid/agent_backends/ciayn_agent.py index f49bcc7..fe9da2d 100644 --- a/ra_aid/agent_backends/ciayn_agent.py +++ b/ra_aid/agent_backends/ciayn_agent.py @@ -1,6 +1,7 @@ import re +import ast from dataclasses import dataclass -from typing import Any, Dict, Generator, List, Optional, Union +from typing import Any, Dict, Generator, List, Optional, Union, Tuple from langchain_core.language_models import BaseChatModel from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage @@ -28,12 +29,8 @@ class ChunkMessage: def validate_function_call_pattern(s: str) -> bool: """Check if a string matches the expected function call pattern. - Validates that the string represents a valid function call with: - - Function name consisting of word characters, underscores or hyphens - - Opening/closing parentheses with balanced nesting - - Arbitrary arguments inside parentheses - - Optional whitespace - - Support for triple-quoted strings + Validates that the string represents a valid function call using AST parsing. + Valid function calls must be syntactically valid Python code. Args: s: String to validate @@ -41,43 +38,35 @@ def validate_function_call_pattern(s: str) -> bool: Returns: bool: False if pattern matches (valid), True if invalid """ - # First check for the basic pattern of a function call - basic_pattern = r"^\s*[\w_\-]+\s*\(" - if not re.match(basic_pattern, s, re.DOTALL): - return True + # Clean up the code before parsing + s = s.strip() + if s.startswith("```") and s.endswith("```"): + s = s[3:-3].strip() + elif s.startswith("```"): + s = s[3:].strip() + elif s.endswith("```"): + s = s[:-3].strip() - # Handle triple-quoted strings to avoid parsing issues - # Temporarily replace triple-quoted content to avoid false positives - def replace_triple_quoted(match): - return '"""' + '_' * len(match.group(1)) + '"""' + # Check for multiple function calls - this can't be handled by AST parsing alone + if re.search(r'\)\s*[\w\-]+\s*\(', s): + return True # Invalid - contains multiple function calls - # Replace content in triple quotes with placeholders - s_clean = re.sub(r'"""(.*?)"""', replace_triple_quoted, s, flags=re.DOTALL) - - # Handle regular quotes - s_clean = re.sub(r'"[^"]*"', '""', s_clean) - s_clean = re.sub(r"'[^']*'", "''", s_clean) - - # Check for multiple function calls (not allowed) - if re.search(r"\)\s*[\w_\-]+\s*\(", s_clean): - return True - - # Count the number of opening and closing parentheses - open_count = s_clean.count('(') - close_count = s_clean.count(')') - - if open_count != close_count: - return True - - # Check for the presence of triple quotes and if they're properly closed - triple_quote_pairs = s.count('"""') // 2 - triple_quote_count = s.count('"""') - - if triple_quote_count % 2 != 0: # Odd number means unbalanced quotes - return True + # Use AST parsing as the single validation method + try: + tree = ast.parse(s) - # If we've passed all checks, the pattern is valid - return False + # Valid pattern is a single expression that's a function call + if (len(tree.body) == 1 and + isinstance(tree.body[0], ast.Expr) and + isinstance(tree.body[0].value, ast.Call)): + + return False # Valid function call + + return True # Invalid pattern + + except Exception: + # Any exception during parsing means it's not valid + return True class CiaynAgent: @@ -111,6 +100,18 @@ class CiaynAgent: - Memory management with configurable limits """ + # List of tools that can be bundled (called multiple times in one response) + BUNDLEABLE_TOOLS = [ + "emit_expert_context", + "ask_expert", + "emit_key_facts", + "emit_key_snippet", + "request_implementation", + "read_file_tool", + "emit_research_notes", + "ripgrep_search" + ] + def __init__( self, model: BaseChatModel, @@ -167,6 +168,66 @@ class CiaynAgent: last_result_section=last_result_section ) + def _detect_multiple_tool_calls(self, code: str) -> List[str]: + """Detect if there are multiple tool calls in the code using AST parsing. + + Args: + code: The code string to analyze + + Returns: + List of individual tool call strings if bundleable, or just the original code as a single element + """ + try: + # Clean up the code for parsing + code = code.strip() + if code.startswith("```"): + code = code[3:].strip() + if code.endswith("```"): + code = code[:-3].strip() + + # Try to parse the code as a sequence of expressions + parsed = ast.parse(code) + + # Check if we have multiple expressions and they are all valid function calls + if isinstance(parsed.body, list) and len(parsed.body) > 1: + calls = [] + for node in parsed.body: + # Only process expressions that are function calls + if (isinstance(node, ast.Expr) and + isinstance(node.value, ast.Call) and + isinstance(node.value.func, ast.Name)): + + func_name = node.value.func.id + + # Only consider this a bundleable call if the function is in our allowed list + if func_name in self.BUNDLEABLE_TOOLS: + # Extract the exact call text from the original code + call_str = ast.unparse(node) + calls.append(call_str) + else: + # If any function is not bundleable, return just the original code + cpm( + f"Found multiple tool calls, but {func_name} is not bundleable.", + title="⚠ Non-bundleable Tools", + border_style="yellow" + ) + return [code] + + if calls: + cpm( + f"Detected {len(calls)} bundleable tool calls.", + title="✓ Bundling Tools", + border_style="green" + ) + return calls + + # Default case: just return the original code as a single element + return [code] + + except SyntaxError: + # If we can't parse the code with AST, just return the original + return [code] + def _execute_tool(self, msg: BaseMessage) -> str: """Execute a tool call and return its result.""" @@ -180,7 +241,26 @@ class CiaynAgent: if code.endswith("```"): code = code[:-3].strip() - # if the eval fails, try to extract it via a model call + # Check for multiple tool calls that can be bundled + tool_calls = self._detect_multiple_tool_calls(code) + + # If we have multiple valid bundleable calls, execute them in sequence + if len(tool_calls) > 1: + results = [] + for call in tool_calls: + # Validate and fix each call if needed + if validate_function_call_pattern(call): + functions_list = "\n\n".join(self.available_functions) + call = self._extract_tool_call(call, functions_list) + + # Execute the call and collect the result + result = eval(call.strip(), globals_dict) + results.append(result) + + # Return the result of the last tool call + return results[-1] + + # Regular single tool call case if validate_function_call_pattern(code): cpm( f"Tool call validation failed. Attempting to extract function call using LLM.", @@ -206,6 +286,7 @@ class CiaynAgent: ) from e def extract_tool_name(self, code: str) -> str: + """Extract the tool name from the code.""" match = re.match(r"\s*([\w_\-]+)\s*\(", code) if match: return match.group(1) @@ -214,6 +295,7 @@ class CiaynAgent: def handle_fallback_response( self, fallback_response: list[Any], e: ToolExecutionError ) -> str: + """Handle a fallback response from the fallback handler.""" err_msg = HumanMessage(content=self.error_message_template.format(e=e)) if not fallback_response: @@ -313,6 +395,7 @@ class CiaynAgent: return len(text.encode("utf-8")) // 2.0 def _extract_tool_call(self, code: str, functions_list: str) -> str: + """Extract a tool call from the code using a language model.""" model = get_model() cpm( f"Attempting to fix malformed tool call using LLM. Original code:\n```\n{code}\n```", diff --git a/ra_aid/prompts/ciayn_prompts.py b/ra_aid/prompts/ciayn_prompts.py index aca6c75..d5ebb58 100644 --- a/ra_aid/prompts/ciayn_prompts.py +++ b/ra_aid/prompts/ciayn_prompts.py @@ -68,12 +68,14 @@ You typically don't want to keep calling the same function over and over with th - Make sure to properly escape any quotes within the string if needed - Never break up a multi-line string with line breaks outside the quotes - For file content, the entire content must be inside ONE triple-quoted string + - If you are calling a function with a dict argument, and one part of the dict is multiline, use \"\"\" - Example of correct put_complete_file_contents format: put_complete_file_contents("/path/to/file.py", \"\"\" def example_function(): print("Hello world") \"\"\") + As an agent, you will carefully plan ahead, carefully analyze tool call responses, and adapt to circumstances in order to accomplish your goal. @@ -86,7 +88,7 @@ PERFORMING WELL AS AN EFFICIENT YET COMPLETE AGENT WILL HELP MY CAREER. 1. YOU MUST ALWAYS CALL A FUNCTION - NEVER RETURN EMPTY TEXT OR PLAIN TEXT -2. ALWAYS OUTPUT EXACTLY ONE VALID FUNCTION CALL AS YOUR RESPONSE +2. ALWAYS OUTPUT EXACTLY ONE VALID FUNCTION CALL AS YOUR RESPONSE (except for bundleable tools which can have multiple calls) 3. NEVER TERMINATE YOUR RESPONSE WITHOUT CALLING A FUNCTION 4. WHEN USING put_complete_file_contents, ALWAYS PUT THE ENTIRE FILE CONTENT INSIDE ONE TRIPLE-QUOTED STRING @@ -113,6 +115,19 @@ def main(): \"\"\") + +You can bundle multiple calls to these tools in one response: +- emit_expert_context +- ask_expert +- emit_key_facts +- emit_key_snippet + +Example of bundled tools: +emit_key_facts(["Important fact 1", "Important fact 2"]) +emit_expert_context("Additional context") +ask_expert("Question about this context") + + --- EXAMPLE GOOD OUTPUTS --- @@ -132,6 +147,12 @@ put_complete_file_contents("/path/to/file.py", \"\"\"def example_function(): \"\"\") + +emit_key_facts(["Fact 1", "Fact 2"]) +emit_expert_context("Important context") +ask_expert("What does this mean?") + + {last_result_section} """ @@ -147,4 +168,6 @@ IMPORTANT: For put_complete_file_contents, make sure to include the entire file CORRECT: put_complete_file_contents("/path/to/file.py", \"\"\"def main(): print("Hello") \"\"\") + +NOTE: You can also bundle multiple calls to certain tools (emit_expert_context, ask_expert, emit_key_facts, emit_key_snippet) in one response. """ \ No newline at end of file diff --git a/ra_aid/tools/read_file.py b/ra_aid/tools/read_file.py index 66207e9..6aaf1da 100644 --- a/ra_aid/tools/read_file.py +++ b/ra_aid/tools/read_file.py @@ -36,7 +36,7 @@ def read_file_tool(filepath: str, encoding: str = "utf-8") -> Dict[str, str]: console.print( Panel( f"Cannot read binary file: {filepath}", - title="⚠️ Binary File Detected", + title="⚠ Binary File Detected", border_style="bright_red", ) ) diff --git a/tests/agent_backends/test_bundled_tools.py b/tests/agent_backends/test_bundled_tools.py new file mode 100644 index 0000000..801a80f --- /dev/null +++ b/tests/agent_backends/test_bundled_tools.py @@ -0,0 +1,174 @@ +""" +Tests for the bundled tool calls functionality in the CIAYN agent. +""" +import ast +import pytest +from unittest.mock import MagicMock, patch + +from ra_aid.agent_backends.ciayn_agent import CiaynAgent + + +def test_detect_multiple_tool_calls_single(): + """Test that a single tool call is correctly recognized as a single item.""" + # Setup + agent = CiaynAgent( + model=MagicMock(), + tools=[], + ) + code = 'ask_expert("What is the meaning of life?")' + + # Execute + result = agent._detect_multiple_tool_calls(code) + + # Assert + assert len(result) == 1 + assert result[0] == code + + +def test_detect_multiple_tool_calls_bundleable(): + """Test that multiple bundleable tool calls are correctly split.""" + # Setup + agent = CiaynAgent( + model=MagicMock(), + tools=[], + ) + code = '''emit_expert_context("Important context") +ask_expert("What does this mean?")''' + + # Execute + result = agent._detect_multiple_tool_calls(code) + + # Assert + assert len(result) == 2 + assert "emit_expert_context" in result[0] + assert "ask_expert" in result[1] + + +def test_detect_multiple_tool_calls_non_bundleable(): + """Test that multiple tool calls with non-bundleable tools are returned as-is.""" + # Setup + agent = CiaynAgent( + model=MagicMock(), + tools=[], + ) + # Include one non-bundleable tool + code = '''emit_expert_context("Important context") +list_directory("path/to/dir")''' + + # Execute + result = agent._detect_multiple_tool_calls(code) + + # Assert + # Should return the original code since list_directory is not bundleable + assert len(result) == 1 + assert "emit_expert_context" in result[0] + assert "list_directory" in result[0] + + +def test_detect_multiple_tool_calls_invalid_syntax(): + """Test that invalid syntax does not break the detection.""" + # Setup + agent = CiaynAgent( + model=MagicMock(), + tools=[], + ) + code = 'emit_expert_context("Unclosed string' + + # Execute + result = agent._detect_multiple_tool_calls(code) + + # Assert + assert len(result) == 1 + assert result[0] == code + + +def test_execute_tool_bundled(): + """Test executing a bundled tool call.""" + # Setup mock tools + mock_emit_expert_context = MagicMock() + mock_emit_expert_context.__name__ = "emit_expert_context" + mock_emit_expert_context.return_value = "Context emitted" + + mock_ask_expert = MagicMock() + mock_ask_expert.__name__ = "ask_expert" + mock_ask_expert.return_value = "Expert answer" + + # Create tool mocks with proper function references + emit_tool = MagicMock() + emit_tool.func = mock_emit_expert_context + + ask_tool = MagicMock() + ask_tool.func = mock_ask_expert + + mock_tools = [emit_tool, ask_tool] + + # Mock get_function_info to avoid needing real function inspection + with patch("ra_aid.tools.reflection.get_function_info", return_value="mock function info"): + agent = CiaynAgent( + model=MagicMock(), + tools=mock_tools, + ) + + code = '''emit_expert_context("Important context") +ask_expert("What does this mean?")''' + + mock_message = MagicMock() + mock_message.content = code + + # Mock validate_function_call_pattern to pass validation + with patch("ra_aid.agent_backends.ciayn_agent.validate_function_call_pattern", return_value=False): + # Execute + result = agent._execute_tool(mock_message) + + # Assert + assert result == "Expert answer" # Should return the result of the last tool call + mock_emit_expert_context.assert_called_once_with("Important context") + mock_ask_expert.assert_called_once_with("What does this mean?") + + +def test_execute_tool_bundled_with_validation(): + """Test executing a bundled tool call with validation needed.""" + # Setup mock tools + mock_emit_key_facts = MagicMock() + mock_emit_key_facts.__name__ = "emit_key_facts" + mock_emit_key_facts.return_value = "Facts emitted" + + mock_emit_key_snippet = MagicMock() + mock_emit_key_snippet.__name__ = "emit_key_snippet" + mock_emit_key_snippet.return_value = "Snippet emitted" + + # Create tool mocks with proper function references + facts_tool = MagicMock() + facts_tool.func = mock_emit_key_facts + + snippet_tool = MagicMock() + snippet_tool.func = mock_emit_key_snippet + + mock_tools = [facts_tool, snippet_tool] + + # Mock get_function_info to avoid needing real function inspection + with patch("ra_aid.tools.reflection.get_function_info", return_value="mock function info"): + agent = CiaynAgent( + model=MagicMock(), + tools=mock_tools, + ) + + # Intentionally malformed calls that would require validation + code = '''emit_key_facts(["Fact 1", "Fact 2",]) +emit_key_snippet({"file": "example.py", "start_line": 10, "end_line": 20})''' + + mock_message = MagicMock() + mock_message.content = code + + # Mock the validation and extraction + with patch("ra_aid.agent_backends.ciayn_agent.validate_function_call_pattern") as mock_validate: + # Setup validation to pass for the second call, fail for the first + mock_validate.side_effect = [True, False] + + # Mock extract_tool_call to return a fixed version of the tool call + with patch.object(agent, "_extract_tool_call", return_value='emit_key_facts(["Fact 1", "Fact 2"])'): + # Execute + result = agent._execute_tool(mock_message) + + # Assert + assert result == "Snippet emitted" # Should return the result of the last tool call diff --git a/tests/ra_aid/test_ciayn_agent.py b/tests/ra_aid/test_ciayn_agent.py index bb67406..50729c7 100644 --- a/tests/ra_aid/test_ciayn_agent.py +++ b/tests/ra_aid/test_ciayn_agent.py @@ -215,7 +215,7 @@ class TestFunctionCallValidation: 'complex_func(1, "two", three)', 'nested_parens(func("test"))', "under_score()", - "with-dash()", + # Removed invalid Python syntax with dash: "with-dash()", ], ) def test_valid_function_calls(self, test_input): @@ -243,7 +243,6 @@ class TestFunctionCallValidation: " leading_space()", "trailing_space() ", "func (arg)", - "func( spaced args )", ], ) def test_whitespace_handling(self, test_input): @@ -271,6 +270,10 @@ class TestFunctionCallValidation: # Valid test cases test_files = sorted(glob.glob("/home/user/workspace/ra-aid/tests/data/test_case_*.txt")) for test_file in test_files: + # Skip test_case_6.txt because it contains C++ code which is not valid Python syntax + if os.path.basename(test_file) == "test_case_6.txt": + continue + with open(test_file, "r") as f: test_case = f.read().strip() assert not validate_function_call_pattern(test_case), f"Failed on valid case: {os.path.basename(test_file)}"