Extract tool reflection.
This commit is contained in:
parent
8d0bacdcda
commit
63be5248e1
|
|
@ -1,7 +1,9 @@
|
|||
import inspect
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Any, Generator, List, Optional, Union
|
||||
import re
|
||||
from typing import Dict, Any, Generator, List, Optional, Union
|
||||
|
||||
from ra_aid.tools.reflection import get_function_info
|
||||
|
||||
from langchain_core.messages import AIMessage, HumanMessage, BaseMessage, SystemMessage
|
||||
from ra_aid.exceptions import ToolExecutionError
|
||||
|
|
@ -14,6 +16,24 @@ class ChunkMessage:
|
|||
content: str
|
||||
status: str
|
||||
|
||||
def validate_function_call_pattern(s: str) -> bool:
|
||||
"""Check if a string matches the expected function call pattern.
|
||||
|
||||
Validates that the string represents a valid function call with:
|
||||
- Function name consisting of word characters, underscores or hyphens
|
||||
- Opening/closing parentheses with balanced nesting
|
||||
- Arbitrary arguments inside parentheses
|
||||
- Optional whitespace
|
||||
|
||||
Args:
|
||||
s: String to validate
|
||||
|
||||
Returns:
|
||||
bool: False if pattern matches (valid), True if invalid
|
||||
"""
|
||||
pattern = r"^\s*[\w_\-]+\s*\([^)(]*(?:\([^)(]*\)[^)(]*)*\)\s*$"
|
||||
return not re.match(pattern, s, re.DOTALL)
|
||||
|
||||
class CiaynAgent:
|
||||
"""Code Is All You Need (CIAYN) agent that uses generated Python code for tool interaction.
|
||||
|
||||
|
|
@ -45,26 +65,6 @@ class CiaynAgent:
|
|||
- Memory management with configurable limits
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _does_not_conform_to_pattern(s):
|
||||
pattern = r"^\s*[\w_\-]+\s*\([^)(]*(?:\([^)(]*\)[^)(]*)*\)\s*$"
|
||||
return not re.match(pattern, s, re.DOTALL)
|
||||
|
||||
def _get_function_info(self, func):
|
||||
"""
|
||||
Returns a well-formatted string containing the function signature and docstring,
|
||||
designed to be easily readable by both humans and LLMs.
|
||||
"""
|
||||
signature = inspect.signature(func)
|
||||
docstring = inspect.getdoc(func)
|
||||
if docstring is None:
|
||||
docstring = "No docstring provided"
|
||||
full_signature = f"{func.__name__}{signature}"
|
||||
info = f"""{full_signature}
|
||||
\"\"\"
|
||||
{docstring}
|
||||
\"\"\""""
|
||||
return info
|
||||
|
||||
def __init__(self, model, tools: list, max_history_messages: int = 50, max_tokens: Optional[int] = 100000):
|
||||
"""Initialize the agent with a model and list of tools.
|
||||
|
|
@ -81,7 +81,7 @@ class CiaynAgent:
|
|||
self.max_tokens = max_tokens
|
||||
self.available_functions = []
|
||||
for t in tools:
|
||||
self.available_functions.append(self._get_function_info(t.func))
|
||||
self.available_functions.append(get_function_info(t.func))
|
||||
|
||||
def _build_prompt(self, last_result: Optional[str] = None) -> str:
|
||||
"""Build the prompt for the agent including available tools and context."""
|
||||
|
|
@ -217,7 +217,7 @@ Output **ONLY THE CODE** and **NO MARKDOWN BACKTICKS**"""
|
|||
# code = code.replace("\n", " ")
|
||||
|
||||
# if the eval fails, try to extract it via a model call
|
||||
if self._does_not_conform_to_pattern(code):
|
||||
if validate_function_call_pattern(code):
|
||||
functions_list = "\n\n".join(self.available_functions)
|
||||
code = _extract_tool_call(code, functions_list)
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,36 @@
|
|||
"""Functions for extracting and validating tool function signatures and documentation.
|
||||
|
||||
This module provides utilities for:
|
||||
- Extracting function signatures and docstrings via reflection
|
||||
- Formatting tool information for agent consumption
|
||||
"""
|
||||
|
||||
import inspect
|
||||
|
||||
__all__ = ['get_function_info']
|
||||
|
||||
|
||||
def get_function_info(func):
|
||||
"""Return a well-formatted string containing the function signature and docstring.
|
||||
|
||||
Uses Python's inspect module to extract and format function metadata in a way
|
||||
that is easily readable by both humans and language models.
|
||||
|
||||
Args:
|
||||
func: The function to analyze
|
||||
|
||||
Returns:
|
||||
str: Formatted string containing function name, signature and docstring
|
||||
"""
|
||||
signature = inspect.signature(func)
|
||||
docstring = inspect.getdoc(func)
|
||||
if docstring is None:
|
||||
docstring = "No docstring provided"
|
||||
full_signature = f"{func.__name__}{signature}"
|
||||
info = f"""{full_signature}
|
||||
\"\"\"
|
||||
{docstring}
|
||||
\"\"\""""
|
||||
return info
|
||||
|
||||
|
||||
|
|
@ -2,6 +2,7 @@ import pytest
|
|||
from unittest.mock import Mock, patch
|
||||
from langchain_core.messages import HumanMessage, AIMessage
|
||||
from ra_aid.agents.ciayn_agent import CiaynAgent
|
||||
from ra_aid.agents.ciayn_agent import validate_function_call_pattern
|
||||
|
||||
@pytest.fixture
|
||||
def mock_model():
|
||||
|
|
@ -156,7 +157,7 @@ def test_trim_chat_history_both_limits():
|
|||
assert result[1] == chat_history[-1]
|
||||
|
||||
|
||||
class TestCiaynAgentRegexValidation:
|
||||
class TestFunctionCallValidation:
|
||||
@pytest.mark.parametrize("test_input", [
|
||||
"basic_func()",
|
||||
"func_with_arg(\"test\")",
|
||||
|
|
@ -167,7 +168,7 @@ class TestCiaynAgentRegexValidation:
|
|||
])
|
||||
def test_valid_function_calls(self, test_input):
|
||||
"""Test function call patterns that should pass validation."""
|
||||
assert not CiaynAgent._does_not_conform_to_pattern(test_input)
|
||||
assert not validate_function_call_pattern(test_input)
|
||||
|
||||
@pytest.mark.parametrize("test_input", [
|
||||
"",
|
||||
|
|
@ -179,7 +180,7 @@ class TestCiaynAgentRegexValidation:
|
|||
])
|
||||
def test_invalid_function_calls(self, test_input):
|
||||
"""Test function call patterns that should fail validation."""
|
||||
assert CiaynAgent._does_not_conform_to_pattern(test_input)
|
||||
assert validate_function_call_pattern(test_input)
|
||||
|
||||
@pytest.mark.parametrize("test_input", [
|
||||
" leading_space()",
|
||||
|
|
@ -189,7 +190,7 @@ class TestCiaynAgentRegexValidation:
|
|||
])
|
||||
def test_whitespace_handling(self, test_input):
|
||||
"""Test whitespace variations in function calls."""
|
||||
assert not CiaynAgent._does_not_conform_to_pattern(test_input)
|
||||
assert not validate_function_call_pattern(test_input)
|
||||
|
||||
@pytest.mark.parametrize("test_input", [
|
||||
"""multiline(
|
||||
|
|
@ -199,4 +200,4 @@ class TestCiaynAgentRegexValidation:
|
|||
])
|
||||
def test_multiline_responses(self, test_input):
|
||||
"""Test function calls spanning multiple lines."""
|
||||
assert not CiaynAgent._does_not_conform_to_pattern(test_input)
|
||||
assert not validate_function_call_pattern(test_input)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,62 @@
|
|||
import pytest
|
||||
from ra_aid.tools.reflection import get_function_info
|
||||
|
||||
# Sample functions for testing get_function_info
|
||||
def simple_func():
|
||||
"""A simple function with no parameters."""
|
||||
pass
|
||||
|
||||
def typed_func(a: int, b: str = "default") -> bool:
|
||||
"""A function with type hints and default value.
|
||||
|
||||
Args:
|
||||
a: An integer parameter
|
||||
b: A string parameter with default
|
||||
|
||||
Returns:
|
||||
bool: Always returns True
|
||||
"""
|
||||
return True
|
||||
|
||||
def complex_func(pos1, pos2, *args, kw1="default", **kwargs):
|
||||
"""A function with complex signature."""
|
||||
pass
|
||||
|
||||
def no_docstring_func(x):
|
||||
pass
|
||||
|
||||
class TestGetFunctionInfo:
|
||||
def test_simple_function_info(self):
|
||||
"""Test info extraction for simple function."""
|
||||
info = get_function_info(simple_func)
|
||||
assert "simple_func()" in info
|
||||
assert "A simple function with no parameters" in info
|
||||
|
||||
def test_typed_function_info(self):
|
||||
"""Test info extraction for function with type hints."""
|
||||
info = get_function_info(typed_func)
|
||||
assert "typed_func" in info
|
||||
assert "a: int" in info
|
||||
assert "b: str = 'default'" in info
|
||||
assert "-> bool" in info
|
||||
assert "Args:" in info
|
||||
assert "Returns:" in info
|
||||
|
||||
def test_complex_function_info(self):
|
||||
"""Test info extraction for function with complex signature."""
|
||||
info = get_function_info(complex_func)
|
||||
assert "complex_func" in info
|
||||
assert "pos1" in info
|
||||
assert "pos2" in info
|
||||
assert "*args" in info
|
||||
assert "**kwargs" in info
|
||||
assert "kw1='default'" in info
|
||||
assert "A function with complex signature" in info
|
||||
|
||||
def test_no_docstring_function(self):
|
||||
"""Test handling of functions without docstrings."""
|
||||
info = get_function_info(no_docstring_func)
|
||||
assert "no_docstring_func" in info
|
||||
assert "No docstring provided" in info
|
||||
|
||||
|
||||
Loading…
Reference in New Issue