Extract tool reflection.

This commit is contained in:
AI Christianson 2025-01-09 15:35:14 -05:00
parent 8d0bacdcda
commit 63be5248e1
4 changed files with 128 additions and 29 deletions

View File

@ -1,7 +1,9 @@
import inspect
import re
from dataclasses import dataclass
from typing import Dict, Any, Generator, List, Optional, Union
import re
from typing import Dict, Any, Generator, List, Optional, Union
from ra_aid.tools.reflection import get_function_info
from langchain_core.messages import AIMessage, HumanMessage, BaseMessage, SystemMessage
from ra_aid.exceptions import ToolExecutionError
@ -14,6 +16,24 @@ class ChunkMessage:
content: str
status: str
def validate_function_call_pattern(s: str) -> bool:
"""Check if a string matches the expected function call pattern.
Validates that the string represents a valid function call with:
- Function name consisting of word characters, underscores or hyphens
- Opening/closing parentheses with balanced nesting
- Arbitrary arguments inside parentheses
- Optional whitespace
Args:
s: String to validate
Returns:
bool: False if pattern matches (valid), True if invalid
"""
pattern = r"^\s*[\w_\-]+\s*\([^)(]*(?:\([^)(]*\)[^)(]*)*\)\s*$"
return not re.match(pattern, s, re.DOTALL)
class CiaynAgent:
"""Code Is All You Need (CIAYN) agent that uses generated Python code for tool interaction.
@ -45,26 +65,6 @@ class CiaynAgent:
- Memory management with configurable limits
"""
@staticmethod
def _does_not_conform_to_pattern(s):
pattern = r"^\s*[\w_\-]+\s*\([^)(]*(?:\([^)(]*\)[^)(]*)*\)\s*$"
return not re.match(pattern, s, re.DOTALL)
def _get_function_info(self, func):
"""
Returns a well-formatted string containing the function signature and docstring,
designed to be easily readable by both humans and LLMs.
"""
signature = inspect.signature(func)
docstring = inspect.getdoc(func)
if docstring is None:
docstring = "No docstring provided"
full_signature = f"{func.__name__}{signature}"
info = f"""{full_signature}
\"\"\"
{docstring}
\"\"\""""
return info
def __init__(self, model, tools: list, max_history_messages: int = 50, max_tokens: Optional[int] = 100000):
"""Initialize the agent with a model and list of tools.
@ -81,7 +81,7 @@ class CiaynAgent:
self.max_tokens = max_tokens
self.available_functions = []
for t in tools:
self.available_functions.append(self._get_function_info(t.func))
self.available_functions.append(get_function_info(t.func))
def _build_prompt(self, last_result: Optional[str] = None) -> str:
"""Build the prompt for the agent including available tools and context."""
@ -217,7 +217,7 @@ Output **ONLY THE CODE** and **NO MARKDOWN BACKTICKS**"""
# code = code.replace("\n", " ")
# if the eval fails, try to extract it via a model call
if self._does_not_conform_to_pattern(code):
if validate_function_call_pattern(code):
functions_list = "\n\n".join(self.available_functions)
code = _extract_tool_call(code, functions_list)

View File

@ -0,0 +1,36 @@
"""Functions for extracting and validating tool function signatures and documentation.
This module provides utilities for:
- Extracting function signatures and docstrings via reflection
- Formatting tool information for agent consumption
"""
import inspect
__all__ = ['get_function_info']
def get_function_info(func):
"""Return a well-formatted string containing the function signature and docstring.
Uses Python's inspect module to extract and format function metadata in a way
that is easily readable by both humans and language models.
Args:
func: The function to analyze
Returns:
str: Formatted string containing function name, signature and docstring
"""
signature = inspect.signature(func)
docstring = inspect.getdoc(func)
if docstring is None:
docstring = "No docstring provided"
full_signature = f"{func.__name__}{signature}"
info = f"""{full_signature}
\"\"\"
{docstring}
\"\"\""""
return info

View File

@ -2,6 +2,7 @@ import pytest
from unittest.mock import Mock, patch
from langchain_core.messages import HumanMessage, AIMessage
from ra_aid.agents.ciayn_agent import CiaynAgent
from ra_aid.agents.ciayn_agent import validate_function_call_pattern
@pytest.fixture
def mock_model():
@ -156,7 +157,7 @@ def test_trim_chat_history_both_limits():
assert result[1] == chat_history[-1]
class TestCiaynAgentRegexValidation:
class TestFunctionCallValidation:
@pytest.mark.parametrize("test_input", [
"basic_func()",
"func_with_arg(\"test\")",
@ -167,7 +168,7 @@ class TestCiaynAgentRegexValidation:
])
def test_valid_function_calls(self, test_input):
"""Test function call patterns that should pass validation."""
assert not CiaynAgent._does_not_conform_to_pattern(test_input)
assert not validate_function_call_pattern(test_input)
@pytest.mark.parametrize("test_input", [
"",
@ -179,7 +180,7 @@ class TestCiaynAgentRegexValidation:
])
def test_invalid_function_calls(self, test_input):
"""Test function call patterns that should fail validation."""
assert CiaynAgent._does_not_conform_to_pattern(test_input)
assert validate_function_call_pattern(test_input)
@pytest.mark.parametrize("test_input", [
" leading_space()",
@ -189,7 +190,7 @@ class TestCiaynAgentRegexValidation:
])
def test_whitespace_handling(self, test_input):
"""Test whitespace variations in function calls."""
assert not CiaynAgent._does_not_conform_to_pattern(test_input)
assert not validate_function_call_pattern(test_input)
@pytest.mark.parametrize("test_input", [
"""multiline(
@ -199,4 +200,4 @@ class TestCiaynAgentRegexValidation:
])
def test_multiline_responses(self, test_input):
"""Test function calls spanning multiple lines."""
assert not CiaynAgent._does_not_conform_to_pattern(test_input)
assert not validate_function_call_pattern(test_input)

View File

@ -0,0 +1,62 @@
import pytest
from ra_aid.tools.reflection import get_function_info
# Sample functions for testing get_function_info
def simple_func():
"""A simple function with no parameters."""
pass
def typed_func(a: int, b: str = "default") -> bool:
"""A function with type hints and default value.
Args:
a: An integer parameter
b: A string parameter with default
Returns:
bool: Always returns True
"""
return True
def complex_func(pos1, pos2, *args, kw1="default", **kwargs):
"""A function with complex signature."""
pass
def no_docstring_func(x):
pass
class TestGetFunctionInfo:
def test_simple_function_info(self):
"""Test info extraction for simple function."""
info = get_function_info(simple_func)
assert "simple_func()" in info
assert "A simple function with no parameters" in info
def test_typed_function_info(self):
"""Test info extraction for function with type hints."""
info = get_function_info(typed_func)
assert "typed_func" in info
assert "a: int" in info
assert "b: str = 'default'" in info
assert "-> bool" in info
assert "Args:" in info
assert "Returns:" in info
def test_complex_function_info(self):
"""Test info extraction for function with complex signature."""
info = get_function_info(complex_func)
assert "complex_func" in info
assert "pos1" in info
assert "pos2" in info
assert "*args" in info
assert "**kwargs" in info
assert "kw1='default'" in info
assert "A function with complex signature" in info
def test_no_docstring_function(self):
"""Test handling of functions without docstrings."""
info = get_function_info(no_docstring_func)
assert "no_docstring_func" in info
assert "No docstring provided" in info