From 094257e0af4252da40acbc26b79e05e1b06d8b41 Mon Sep 17 00:00:00 2001 From: AI Christianson Date: Thu, 9 Jan 2025 11:51:08 -0500 Subject: [PATCH] Add unit tests for CiaynAgent._does_not_conform_to_pattern. --- ra_aid/agents/ciayn_agent.py | 17 +++++---- ra_aid/tools/agent.py | 3 ++ tests/pytest.ini | 1 + tests/ra_aid/agents/test_ciayn_agent.py | 46 +++++++++++++++++++++++++ 4 files changed, 60 insertions(+), 7 deletions(-) diff --git a/ra_aid/agents/ciayn_agent.py b/ra_aid/agents/ciayn_agent.py index 99c0087..8fd5422 100644 --- a/ra_aid/agents/ciayn_agent.py +++ b/ra_aid/agents/ciayn_agent.py @@ -45,6 +45,11 @@ class CiaynAgent: - Memory management with configurable limits """ + @staticmethod + def _does_not_conform_to_pattern(s): + pattern = r"^\s*[\w_\-]+\s*\([^)(]*(?:\([^)(]*\)[^)(]*)*\)\s*$" + return not re.match(pattern, s, re.DOTALL) + def _get_function_info(self, func): """ Returns a well-formatted string containing the function signature and docstring, @@ -191,6 +196,10 @@ As an agent, you will carefully plan ahead, carefully analyze tool call response We're entrusting you with a lot of autonomy and power, so be efficient and don't mess up. +You have often been criticized for: + +- Making the same function calls over and over, getting stuck in a loop. + DO NOT CLAIM YOU ARE FINISHED UNTIL YOU ACTUALLY ARE! Output **ONLY THE CODE** and **NO MARKDOWN BACKTICKS**""" @@ -204,12 +213,11 @@ Output **ONLY THE CODE** and **NO MARKDOWN BACKTICKS**""" } try: - code = code.strip() # code = code.replace("\n", " ") # if the eval fails, try to extract it via a model call - if _does_not_conform_to_pattern(code): + if self._does_not_conform_to_pattern(code): functions_list = "\n\n".join(self.available_functions) code = _extract_tool_call(code, functions_list) @@ -350,8 +358,3 @@ I got this invalid response from the model, can you format it so it becomes a co ma = matches[0][0].strip() mb = matches[0][1].strip().replace("\n", " ") return f"{ma}({mb})" - - -def _does_not_conform_to_pattern(s): - pattern = r"^\s*[\w_\-]+\((.*?)\)\s*$" - return not re.match(pattern, s, re.DOTALL) \ No newline at end of file diff --git a/ra_aid/tools/agent.py b/ra_aid/tools/agent.py index 3614860..51618cf 100644 --- a/ra_aid/tools/agent.py +++ b/ra_aid/tools/agent.py @@ -161,6 +161,9 @@ def request_web_research(query: str) -> ResearchResult: def request_research_and_implementation(query: str) -> Dict[str, Any]: """Spawn a research agent to investigate and implement the given query. + If you are calling this on behalf of a user request, you must *faithfully* + represent all info the user gave you, sometimes even to the point of repeating the user query verbatim. + Args: query: The research question or project description """ diff --git a/tests/pytest.ini b/tests/pytest.ini index b92be4e..c484eef 100644 --- a/tests/pytest.ini +++ b/tests/pytest.ini @@ -1,2 +1,3 @@ [pytest] markers = + parametrize: Mark test to run with different parameters diff --git a/tests/ra_aid/agents/test_ciayn_agent.py b/tests/ra_aid/agents/test_ciayn_agent.py index 71bc1a8..ead9733 100644 --- a/tests/ra_aid/agents/test_ciayn_agent.py +++ b/tests/ra_aid/agents/test_ciayn_agent.py @@ -154,3 +154,49 @@ def test_trim_chat_history_both_limits(): assert len(result) == 2 # Initial message + 1 message under token limit assert result[0] == initial_messages[0] assert result[1] == chat_history[-1] + + +class TestCiaynAgentRegexValidation: + @pytest.mark.parametrize("test_input", [ + "basic_func()", + "func_with_arg(\"test\")", + "complex_func(1, \"two\", three)", + "nested_parens(func(\"test\"))", + "under_score()", + "with-dash()" + ]) + def test_valid_function_calls(self, test_input): + """Test function call patterns that should pass validation.""" + assert not CiaynAgent._does_not_conform_to_pattern(test_input) + + @pytest.mark.parametrize("test_input", [ + "", + "Invalid!function()", + "missing_parens", + "unmatched(parens))", + "multiple()calls()", + "no spaces()()" + ]) + def test_invalid_function_calls(self, test_input): + """Test function call patterns that should fail validation.""" + assert CiaynAgent._does_not_conform_to_pattern(test_input) + + @pytest.mark.parametrize("test_input", [ + " leading_space()", + "trailing_space() ", + "func (arg)", + "func( spaced args )" + ]) + def test_whitespace_handling(self, test_input): + """Test whitespace variations in function calls.""" + assert not CiaynAgent._does_not_conform_to_pattern(test_input) + + @pytest.mark.parametrize("test_input", [ + """multiline( + arg + )""", + "func(\n arg1,\n arg2\n)" + ]) + def test_multiline_responses(self, test_input): + """Test function calls spanning multiple lines.""" + assert not CiaynAgent._does_not_conform_to_pattern(test_input)