From 63be5248e12ab2da735810309de36b433ccf8b45 Mon Sep 17 00:00:00 2001 From: AI Christianson Date: Thu, 9 Jan 2025 15:35:14 -0500 Subject: [PATCH] Extract tool reflection. --- ra_aid/agents/ciayn_agent.py | 48 +++++++++---------- ra_aid/tools/reflection.py | 36 ++++++++++++++ tests/ra_aid/agents/test_ciayn_agent.py | 11 +++-- tests/ra_aid/tools/test_reflection.py | 62 +++++++++++++++++++++++++ 4 files changed, 128 insertions(+), 29 deletions(-) create mode 100644 ra_aid/tools/reflection.py create mode 100644 tests/ra_aid/tools/test_reflection.py diff --git a/ra_aid/agents/ciayn_agent.py b/ra_aid/agents/ciayn_agent.py index 8fd5422..a1bbd82 100644 --- a/ra_aid/agents/ciayn_agent.py +++ b/ra_aid/agents/ciayn_agent.py @@ -1,7 +1,9 @@ -import inspect +import re from dataclasses import dataclass from typing import Dict, Any, Generator, List, Optional, Union -import re +from typing import Dict, Any, Generator, List, Optional, Union + +from ra_aid.tools.reflection import get_function_info from langchain_core.messages import AIMessage, HumanMessage, BaseMessage, SystemMessage from ra_aid.exceptions import ToolExecutionError @@ -14,6 +16,24 @@ class ChunkMessage: content: str status: str +def validate_function_call_pattern(s: str) -> bool: + """Check if a string matches the expected function call pattern. + + Validates that the string represents a valid function call with: + - Function name consisting of word characters, underscores or hyphens + - Opening/closing parentheses with balanced nesting + - Arbitrary arguments inside parentheses + - Optional whitespace + + Args: + s: String to validate + + Returns: + bool: False if pattern matches (valid), True if invalid + """ + pattern = r"^\s*[\w_\-]+\s*\([^)(]*(?:\([^)(]*\)[^)(]*)*\)\s*$" + return not re.match(pattern, s, re.DOTALL) + class CiaynAgent: """Code Is All You Need (CIAYN) agent that uses generated Python code for tool interaction. @@ -45,26 +65,6 @@ class CiaynAgent: - Memory management with configurable limits """ - @staticmethod - def _does_not_conform_to_pattern(s): - pattern = r"^\s*[\w_\-]+\s*\([^)(]*(?:\([^)(]*\)[^)(]*)*\)\s*$" - return not re.match(pattern, s, re.DOTALL) - - def _get_function_info(self, func): - """ - Returns a well-formatted string containing the function signature and docstring, - designed to be easily readable by both humans and LLMs. - """ - signature = inspect.signature(func) - docstring = inspect.getdoc(func) - if docstring is None: - docstring = "No docstring provided" - full_signature = f"{func.__name__}{signature}" - info = f"""{full_signature} -\"\"\" -{docstring} -\"\"\"""" - return info def __init__(self, model, tools: list, max_history_messages: int = 50, max_tokens: Optional[int] = 100000): """Initialize the agent with a model and list of tools. @@ -81,7 +81,7 @@ class CiaynAgent: self.max_tokens = max_tokens self.available_functions = [] for t in tools: - self.available_functions.append(self._get_function_info(t.func)) + self.available_functions.append(get_function_info(t.func)) def _build_prompt(self, last_result: Optional[str] = None) -> str: """Build the prompt for the agent including available tools and context.""" @@ -217,7 +217,7 @@ Output **ONLY THE CODE** and **NO MARKDOWN BACKTICKS**""" # code = code.replace("\n", " ") # if the eval fails, try to extract it via a model call - if self._does_not_conform_to_pattern(code): + if validate_function_call_pattern(code): functions_list = "\n\n".join(self.available_functions) code = _extract_tool_call(code, functions_list) diff --git a/ra_aid/tools/reflection.py b/ra_aid/tools/reflection.py new file mode 100644 index 0000000..2059629 --- /dev/null +++ b/ra_aid/tools/reflection.py @@ -0,0 +1,36 @@ +"""Functions for extracting and validating tool function signatures and documentation. + +This module provides utilities for: +- Extracting function signatures and docstrings via reflection +- Formatting tool information for agent consumption +""" + +import inspect + +__all__ = ['get_function_info'] + + +def get_function_info(func): + """Return a well-formatted string containing the function signature and docstring. + + Uses Python's inspect module to extract and format function metadata in a way + that is easily readable by both humans and language models. + + Args: + func: The function to analyze + + Returns: + str: Formatted string containing function name, signature and docstring + """ + signature = inspect.signature(func) + docstring = inspect.getdoc(func) + if docstring is None: + docstring = "No docstring provided" + full_signature = f"{func.__name__}{signature}" + info = f"""{full_signature} +\"\"\" +{docstring} +\"\"\"""" + return info + + diff --git a/tests/ra_aid/agents/test_ciayn_agent.py b/tests/ra_aid/agents/test_ciayn_agent.py index ead9733..5803412 100644 --- a/tests/ra_aid/agents/test_ciayn_agent.py +++ b/tests/ra_aid/agents/test_ciayn_agent.py @@ -2,6 +2,7 @@ import pytest from unittest.mock import Mock, patch from langchain_core.messages import HumanMessage, AIMessage from ra_aid.agents.ciayn_agent import CiaynAgent +from ra_aid.agents.ciayn_agent import validate_function_call_pattern @pytest.fixture def mock_model(): @@ -156,7 +157,7 @@ def test_trim_chat_history_both_limits(): assert result[1] == chat_history[-1] -class TestCiaynAgentRegexValidation: +class TestFunctionCallValidation: @pytest.mark.parametrize("test_input", [ "basic_func()", "func_with_arg(\"test\")", @@ -167,7 +168,7 @@ class TestCiaynAgentRegexValidation: ]) def test_valid_function_calls(self, test_input): """Test function call patterns that should pass validation.""" - assert not CiaynAgent._does_not_conform_to_pattern(test_input) + assert not validate_function_call_pattern(test_input) @pytest.mark.parametrize("test_input", [ "", @@ -179,7 +180,7 @@ class TestCiaynAgentRegexValidation: ]) def test_invalid_function_calls(self, test_input): """Test function call patterns that should fail validation.""" - assert CiaynAgent._does_not_conform_to_pattern(test_input) + assert validate_function_call_pattern(test_input) @pytest.mark.parametrize("test_input", [ " leading_space()", @@ -189,7 +190,7 @@ class TestCiaynAgentRegexValidation: ]) def test_whitespace_handling(self, test_input): """Test whitespace variations in function calls.""" - assert not CiaynAgent._does_not_conform_to_pattern(test_input) + assert not validate_function_call_pattern(test_input) @pytest.mark.parametrize("test_input", [ """multiline( @@ -199,4 +200,4 @@ class TestCiaynAgentRegexValidation: ]) def test_multiline_responses(self, test_input): """Test function calls spanning multiple lines.""" - assert not CiaynAgent._does_not_conform_to_pattern(test_input) + assert not validate_function_call_pattern(test_input) diff --git a/tests/ra_aid/tools/test_reflection.py b/tests/ra_aid/tools/test_reflection.py new file mode 100644 index 0000000..5f2004f --- /dev/null +++ b/tests/ra_aid/tools/test_reflection.py @@ -0,0 +1,62 @@ +import pytest +from ra_aid.tools.reflection import get_function_info + +# Sample functions for testing get_function_info +def simple_func(): + """A simple function with no parameters.""" + pass + +def typed_func(a: int, b: str = "default") -> bool: + """A function with type hints and default value. + + Args: + a: An integer parameter + b: A string parameter with default + + Returns: + bool: Always returns True + """ + return True + +def complex_func(pos1, pos2, *args, kw1="default", **kwargs): + """A function with complex signature.""" + pass + +def no_docstring_func(x): + pass + +class TestGetFunctionInfo: + def test_simple_function_info(self): + """Test info extraction for simple function.""" + info = get_function_info(simple_func) + assert "simple_func()" in info + assert "A simple function with no parameters" in info + + def test_typed_function_info(self): + """Test info extraction for function with type hints.""" + info = get_function_info(typed_func) + assert "typed_func" in info + assert "a: int" in info + assert "b: str = 'default'" in info + assert "-> bool" in info + assert "Args:" in info + assert "Returns:" in info + + def test_complex_function_info(self): + """Test info extraction for function with complex signature.""" + info = get_function_info(complex_func) + assert "complex_func" in info + assert "pos1" in info + assert "pos2" in info + assert "*args" in info + assert "**kwargs" in info + assert "kw1='default'" in info + assert "A function with complex signature" in info + + def test_no_docstring_function(self): + """Test handling of functions without docstrings.""" + info = get_function_info(no_docstring_func) + assert "no_docstring_func" in info + assert "No docstring provided" in info + +