Extract tool reflection.

This commit is contained in:
AI Christianson 2025-01-09 15:35:14 -05:00
parent 8d0bacdcda
commit 63be5248e1
4 changed files with 128 additions and 29 deletions

View File

@ -1,7 +1,9 @@
import inspect import re
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, Any, Generator, List, Optional, Union from typing import Dict, Any, Generator, List, Optional, Union
import re from typing import Dict, Any, Generator, List, Optional, Union
from ra_aid.tools.reflection import get_function_info
from langchain_core.messages import AIMessage, HumanMessage, BaseMessage, SystemMessage from langchain_core.messages import AIMessage, HumanMessage, BaseMessage, SystemMessage
from ra_aid.exceptions import ToolExecutionError from ra_aid.exceptions import ToolExecutionError
@ -14,6 +16,24 @@ class ChunkMessage:
content: str content: str
status: str status: str
def validate_function_call_pattern(s: str) -> bool:
"""Check if a string matches the expected function call pattern.
Validates that the string represents a valid function call with:
- Function name consisting of word characters, underscores or hyphens
- Opening/closing parentheses with balanced nesting
- Arbitrary arguments inside parentheses
- Optional whitespace
Args:
s: String to validate
Returns:
bool: False if pattern matches (valid), True if invalid
"""
pattern = r"^\s*[\w_\-]+\s*\([^)(]*(?:\([^)(]*\)[^)(]*)*\)\s*$"
return not re.match(pattern, s, re.DOTALL)
class CiaynAgent: class CiaynAgent:
"""Code Is All You Need (CIAYN) agent that uses generated Python code for tool interaction. """Code Is All You Need (CIAYN) agent that uses generated Python code for tool interaction.
@ -45,26 +65,6 @@ class CiaynAgent:
- Memory management with configurable limits - Memory management with configurable limits
""" """
@staticmethod
def _does_not_conform_to_pattern(s):
pattern = r"^\s*[\w_\-]+\s*\([^)(]*(?:\([^)(]*\)[^)(]*)*\)\s*$"
return not re.match(pattern, s, re.DOTALL)
def _get_function_info(self, func):
"""
Returns a well-formatted string containing the function signature and docstring,
designed to be easily readable by both humans and LLMs.
"""
signature = inspect.signature(func)
docstring = inspect.getdoc(func)
if docstring is None:
docstring = "No docstring provided"
full_signature = f"{func.__name__}{signature}"
info = f"""{full_signature}
\"\"\"
{docstring}
\"\"\""""
return info
def __init__(self, model, tools: list, max_history_messages: int = 50, max_tokens: Optional[int] = 100000): def __init__(self, model, tools: list, max_history_messages: int = 50, max_tokens: Optional[int] = 100000):
"""Initialize the agent with a model and list of tools. """Initialize the agent with a model and list of tools.
@ -81,7 +81,7 @@ class CiaynAgent:
self.max_tokens = max_tokens self.max_tokens = max_tokens
self.available_functions = [] self.available_functions = []
for t in tools: for t in tools:
self.available_functions.append(self._get_function_info(t.func)) self.available_functions.append(get_function_info(t.func))
def _build_prompt(self, last_result: Optional[str] = None) -> str: def _build_prompt(self, last_result: Optional[str] = None) -> str:
"""Build the prompt for the agent including available tools and context.""" """Build the prompt for the agent including available tools and context."""
@ -217,7 +217,7 @@ Output **ONLY THE CODE** and **NO MARKDOWN BACKTICKS**"""
# code = code.replace("\n", " ") # code = code.replace("\n", " ")
# if the eval fails, try to extract it via a model call # if the eval fails, try to extract it via a model call
if self._does_not_conform_to_pattern(code): if validate_function_call_pattern(code):
functions_list = "\n\n".join(self.available_functions) functions_list = "\n\n".join(self.available_functions)
code = _extract_tool_call(code, functions_list) code = _extract_tool_call(code, functions_list)

View File

@ -0,0 +1,36 @@
"""Functions for extracting and validating tool function signatures and documentation.
This module provides utilities for:
- Extracting function signatures and docstrings via reflection
- Formatting tool information for agent consumption
"""
import inspect
__all__ = ['get_function_info']
def get_function_info(func):
"""Return a well-formatted string containing the function signature and docstring.
Uses Python's inspect module to extract and format function metadata in a way
that is easily readable by both humans and language models.
Args:
func: The function to analyze
Returns:
str: Formatted string containing function name, signature and docstring
"""
signature = inspect.signature(func)
docstring = inspect.getdoc(func)
if docstring is None:
docstring = "No docstring provided"
full_signature = f"{func.__name__}{signature}"
info = f"""{full_signature}
\"\"\"
{docstring}
\"\"\""""
return info

View File

@ -2,6 +2,7 @@ import pytest
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
from langchain_core.messages import HumanMessage, AIMessage from langchain_core.messages import HumanMessage, AIMessage
from ra_aid.agents.ciayn_agent import CiaynAgent from ra_aid.agents.ciayn_agent import CiaynAgent
from ra_aid.agents.ciayn_agent import validate_function_call_pattern
@pytest.fixture @pytest.fixture
def mock_model(): def mock_model():
@ -156,7 +157,7 @@ def test_trim_chat_history_both_limits():
assert result[1] == chat_history[-1] assert result[1] == chat_history[-1]
class TestCiaynAgentRegexValidation: class TestFunctionCallValidation:
@pytest.mark.parametrize("test_input", [ @pytest.mark.parametrize("test_input", [
"basic_func()", "basic_func()",
"func_with_arg(\"test\")", "func_with_arg(\"test\")",
@ -167,7 +168,7 @@ class TestCiaynAgentRegexValidation:
]) ])
def test_valid_function_calls(self, test_input): def test_valid_function_calls(self, test_input):
"""Test function call patterns that should pass validation.""" """Test function call patterns that should pass validation."""
assert not CiaynAgent._does_not_conform_to_pattern(test_input) assert not validate_function_call_pattern(test_input)
@pytest.mark.parametrize("test_input", [ @pytest.mark.parametrize("test_input", [
"", "",
@ -179,7 +180,7 @@ class TestCiaynAgentRegexValidation:
]) ])
def test_invalid_function_calls(self, test_input): def test_invalid_function_calls(self, test_input):
"""Test function call patterns that should fail validation.""" """Test function call patterns that should fail validation."""
assert CiaynAgent._does_not_conform_to_pattern(test_input) assert validate_function_call_pattern(test_input)
@pytest.mark.parametrize("test_input", [ @pytest.mark.parametrize("test_input", [
" leading_space()", " leading_space()",
@ -189,7 +190,7 @@ class TestCiaynAgentRegexValidation:
]) ])
def test_whitespace_handling(self, test_input): def test_whitespace_handling(self, test_input):
"""Test whitespace variations in function calls.""" """Test whitespace variations in function calls."""
assert not CiaynAgent._does_not_conform_to_pattern(test_input) assert not validate_function_call_pattern(test_input)
@pytest.mark.parametrize("test_input", [ @pytest.mark.parametrize("test_input", [
"""multiline( """multiline(
@ -199,4 +200,4 @@ class TestCiaynAgentRegexValidation:
]) ])
def test_multiline_responses(self, test_input): def test_multiline_responses(self, test_input):
"""Test function calls spanning multiple lines.""" """Test function calls spanning multiple lines."""
assert not CiaynAgent._does_not_conform_to_pattern(test_input) assert not validate_function_call_pattern(test_input)

View File

@ -0,0 +1,62 @@
import pytest
from ra_aid.tools.reflection import get_function_info
# Sample functions for testing get_function_info
def simple_func():
"""A simple function with no parameters."""
pass
def typed_func(a: int, b: str = "default") -> bool:
"""A function with type hints and default value.
Args:
a: An integer parameter
b: A string parameter with default
Returns:
bool: Always returns True
"""
return True
def complex_func(pos1, pos2, *args, kw1="default", **kwargs):
"""A function with complex signature."""
pass
def no_docstring_func(x):
pass
class TestGetFunctionInfo:
def test_simple_function_info(self):
"""Test info extraction for simple function."""
info = get_function_info(simple_func)
assert "simple_func()" in info
assert "A simple function with no parameters" in info
def test_typed_function_info(self):
"""Test info extraction for function with type hints."""
info = get_function_info(typed_func)
assert "typed_func" in info
assert "a: int" in info
assert "b: str = 'default'" in info
assert "-> bool" in info
assert "Args:" in info
assert "Returns:" in info
def test_complex_function_info(self):
"""Test info extraction for function with complex signature."""
info = get_function_info(complex_func)
assert "complex_func" in info
assert "pos1" in info
assert "pos2" in info
assert "*args" in info
assert "**kwargs" in info
assert "kw1='default'" in info
assert "A function with complex signature" in info
def test_no_docstring_function(self):
"""Test handling of functions without docstrings."""
info = get_function_info(no_docstring_func)
assert "no_docstring_func" in info
assert "No docstring provided" in info