Extract tool reflection.
This commit is contained in:
parent
8d0bacdcda
commit
63be5248e1
|
|
@ -1,7 +1,9 @@
|
||||||
import inspect
|
import re
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Dict, Any, Generator, List, Optional, Union
|
from typing import Dict, Any, Generator, List, Optional, Union
|
||||||
import re
|
from typing import Dict, Any, Generator, List, Optional, Union
|
||||||
|
|
||||||
|
from ra_aid.tools.reflection import get_function_info
|
||||||
|
|
||||||
from langchain_core.messages import AIMessage, HumanMessage, BaseMessage, SystemMessage
|
from langchain_core.messages import AIMessage, HumanMessage, BaseMessage, SystemMessage
|
||||||
from ra_aid.exceptions import ToolExecutionError
|
from ra_aid.exceptions import ToolExecutionError
|
||||||
|
|
@ -14,6 +16,24 @@ class ChunkMessage:
|
||||||
content: str
|
content: str
|
||||||
status: str
|
status: str
|
||||||
|
|
||||||
|
def validate_function_call_pattern(s: str) -> bool:
|
||||||
|
"""Check if a string matches the expected function call pattern.
|
||||||
|
|
||||||
|
Validates that the string represents a valid function call with:
|
||||||
|
- Function name consisting of word characters, underscores or hyphens
|
||||||
|
- Opening/closing parentheses with balanced nesting
|
||||||
|
- Arbitrary arguments inside parentheses
|
||||||
|
- Optional whitespace
|
||||||
|
|
||||||
|
Args:
|
||||||
|
s: String to validate
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: False if pattern matches (valid), True if invalid
|
||||||
|
"""
|
||||||
|
pattern = r"^\s*[\w_\-]+\s*\([^)(]*(?:\([^)(]*\)[^)(]*)*\)\s*$"
|
||||||
|
return not re.match(pattern, s, re.DOTALL)
|
||||||
|
|
||||||
class CiaynAgent:
|
class CiaynAgent:
|
||||||
"""Code Is All You Need (CIAYN) agent that uses generated Python code for tool interaction.
|
"""Code Is All You Need (CIAYN) agent that uses generated Python code for tool interaction.
|
||||||
|
|
||||||
|
|
@ -45,26 +65,6 @@ class CiaynAgent:
|
||||||
- Memory management with configurable limits
|
- Memory management with configurable limits
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _does_not_conform_to_pattern(s):
|
|
||||||
pattern = r"^\s*[\w_\-]+\s*\([^)(]*(?:\([^)(]*\)[^)(]*)*\)\s*$"
|
|
||||||
return not re.match(pattern, s, re.DOTALL)
|
|
||||||
|
|
||||||
def _get_function_info(self, func):
|
|
||||||
"""
|
|
||||||
Returns a well-formatted string containing the function signature and docstring,
|
|
||||||
designed to be easily readable by both humans and LLMs.
|
|
||||||
"""
|
|
||||||
signature = inspect.signature(func)
|
|
||||||
docstring = inspect.getdoc(func)
|
|
||||||
if docstring is None:
|
|
||||||
docstring = "No docstring provided"
|
|
||||||
full_signature = f"{func.__name__}{signature}"
|
|
||||||
info = f"""{full_signature}
|
|
||||||
\"\"\"
|
|
||||||
{docstring}
|
|
||||||
\"\"\""""
|
|
||||||
return info
|
|
||||||
|
|
||||||
def __init__(self, model, tools: list, max_history_messages: int = 50, max_tokens: Optional[int] = 100000):
|
def __init__(self, model, tools: list, max_history_messages: int = 50, max_tokens: Optional[int] = 100000):
|
||||||
"""Initialize the agent with a model and list of tools.
|
"""Initialize the agent with a model and list of tools.
|
||||||
|
|
@ -81,7 +81,7 @@ class CiaynAgent:
|
||||||
self.max_tokens = max_tokens
|
self.max_tokens = max_tokens
|
||||||
self.available_functions = []
|
self.available_functions = []
|
||||||
for t in tools:
|
for t in tools:
|
||||||
self.available_functions.append(self._get_function_info(t.func))
|
self.available_functions.append(get_function_info(t.func))
|
||||||
|
|
||||||
def _build_prompt(self, last_result: Optional[str] = None) -> str:
|
def _build_prompt(self, last_result: Optional[str] = None) -> str:
|
||||||
"""Build the prompt for the agent including available tools and context."""
|
"""Build the prompt for the agent including available tools and context."""
|
||||||
|
|
@ -217,7 +217,7 @@ Output **ONLY THE CODE** and **NO MARKDOWN BACKTICKS**"""
|
||||||
# code = code.replace("\n", " ")
|
# code = code.replace("\n", " ")
|
||||||
|
|
||||||
# if the eval fails, try to extract it via a model call
|
# if the eval fails, try to extract it via a model call
|
||||||
if self._does_not_conform_to_pattern(code):
|
if validate_function_call_pattern(code):
|
||||||
functions_list = "\n\n".join(self.available_functions)
|
functions_list = "\n\n".join(self.available_functions)
|
||||||
code = _extract_tool_call(code, functions_list)
|
code = _extract_tool_call(code, functions_list)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,36 @@
|
||||||
|
"""Functions for extracting and validating tool function signatures and documentation.
|
||||||
|
|
||||||
|
This module provides utilities for:
|
||||||
|
- Extracting function signatures and docstrings via reflection
|
||||||
|
- Formatting tool information for agent consumption
|
||||||
|
"""
|
||||||
|
|
||||||
|
import inspect
|
||||||
|
|
||||||
|
__all__ = ['get_function_info']
|
||||||
|
|
||||||
|
|
||||||
|
def get_function_info(func):
|
||||||
|
"""Return a well-formatted string containing the function signature and docstring.
|
||||||
|
|
||||||
|
Uses Python's inspect module to extract and format function metadata in a way
|
||||||
|
that is easily readable by both humans and language models.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
func: The function to analyze
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: Formatted string containing function name, signature and docstring
|
||||||
|
"""
|
||||||
|
signature = inspect.signature(func)
|
||||||
|
docstring = inspect.getdoc(func)
|
||||||
|
if docstring is None:
|
||||||
|
docstring = "No docstring provided"
|
||||||
|
full_signature = f"{func.__name__}{signature}"
|
||||||
|
info = f"""{full_signature}
|
||||||
|
\"\"\"
|
||||||
|
{docstring}
|
||||||
|
\"\"\""""
|
||||||
|
return info
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -2,6 +2,7 @@ import pytest
|
||||||
from unittest.mock import Mock, patch
|
from unittest.mock import Mock, patch
|
||||||
from langchain_core.messages import HumanMessage, AIMessage
|
from langchain_core.messages import HumanMessage, AIMessage
|
||||||
from ra_aid.agents.ciayn_agent import CiaynAgent
|
from ra_aid.agents.ciayn_agent import CiaynAgent
|
||||||
|
from ra_aid.agents.ciayn_agent import validate_function_call_pattern
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_model():
|
def mock_model():
|
||||||
|
|
@ -156,7 +157,7 @@ def test_trim_chat_history_both_limits():
|
||||||
assert result[1] == chat_history[-1]
|
assert result[1] == chat_history[-1]
|
||||||
|
|
||||||
|
|
||||||
class TestCiaynAgentRegexValidation:
|
class TestFunctionCallValidation:
|
||||||
@pytest.mark.parametrize("test_input", [
|
@pytest.mark.parametrize("test_input", [
|
||||||
"basic_func()",
|
"basic_func()",
|
||||||
"func_with_arg(\"test\")",
|
"func_with_arg(\"test\")",
|
||||||
|
|
@ -167,7 +168,7 @@ class TestCiaynAgentRegexValidation:
|
||||||
])
|
])
|
||||||
def test_valid_function_calls(self, test_input):
|
def test_valid_function_calls(self, test_input):
|
||||||
"""Test function call patterns that should pass validation."""
|
"""Test function call patterns that should pass validation."""
|
||||||
assert not CiaynAgent._does_not_conform_to_pattern(test_input)
|
assert not validate_function_call_pattern(test_input)
|
||||||
|
|
||||||
@pytest.mark.parametrize("test_input", [
|
@pytest.mark.parametrize("test_input", [
|
||||||
"",
|
"",
|
||||||
|
|
@ -179,7 +180,7 @@ class TestCiaynAgentRegexValidation:
|
||||||
])
|
])
|
||||||
def test_invalid_function_calls(self, test_input):
|
def test_invalid_function_calls(self, test_input):
|
||||||
"""Test function call patterns that should fail validation."""
|
"""Test function call patterns that should fail validation."""
|
||||||
assert CiaynAgent._does_not_conform_to_pattern(test_input)
|
assert validate_function_call_pattern(test_input)
|
||||||
|
|
||||||
@pytest.mark.parametrize("test_input", [
|
@pytest.mark.parametrize("test_input", [
|
||||||
" leading_space()",
|
" leading_space()",
|
||||||
|
|
@ -189,7 +190,7 @@ class TestCiaynAgentRegexValidation:
|
||||||
])
|
])
|
||||||
def test_whitespace_handling(self, test_input):
|
def test_whitespace_handling(self, test_input):
|
||||||
"""Test whitespace variations in function calls."""
|
"""Test whitespace variations in function calls."""
|
||||||
assert not CiaynAgent._does_not_conform_to_pattern(test_input)
|
assert not validate_function_call_pattern(test_input)
|
||||||
|
|
||||||
@pytest.mark.parametrize("test_input", [
|
@pytest.mark.parametrize("test_input", [
|
||||||
"""multiline(
|
"""multiline(
|
||||||
|
|
@ -199,4 +200,4 @@ class TestCiaynAgentRegexValidation:
|
||||||
])
|
])
|
||||||
def test_multiline_responses(self, test_input):
|
def test_multiline_responses(self, test_input):
|
||||||
"""Test function calls spanning multiple lines."""
|
"""Test function calls spanning multiple lines."""
|
||||||
assert not CiaynAgent._does_not_conform_to_pattern(test_input)
|
assert not validate_function_call_pattern(test_input)
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,62 @@
|
||||||
|
import pytest
|
||||||
|
from ra_aid.tools.reflection import get_function_info
|
||||||
|
|
||||||
|
# Sample functions for testing get_function_info
|
||||||
|
def simple_func():
|
||||||
|
"""A simple function with no parameters."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def typed_func(a: int, b: str = "default") -> bool:
|
||||||
|
"""A function with type hints and default value.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
a: An integer parameter
|
||||||
|
b: A string parameter with default
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: Always returns True
|
||||||
|
"""
|
||||||
|
return True
|
||||||
|
|
||||||
|
def complex_func(pos1, pos2, *args, kw1="default", **kwargs):
|
||||||
|
"""A function with complex signature."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def no_docstring_func(x):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class TestGetFunctionInfo:
|
||||||
|
def test_simple_function_info(self):
|
||||||
|
"""Test info extraction for simple function."""
|
||||||
|
info = get_function_info(simple_func)
|
||||||
|
assert "simple_func()" in info
|
||||||
|
assert "A simple function with no parameters" in info
|
||||||
|
|
||||||
|
def test_typed_function_info(self):
|
||||||
|
"""Test info extraction for function with type hints."""
|
||||||
|
info = get_function_info(typed_func)
|
||||||
|
assert "typed_func" in info
|
||||||
|
assert "a: int" in info
|
||||||
|
assert "b: str = 'default'" in info
|
||||||
|
assert "-> bool" in info
|
||||||
|
assert "Args:" in info
|
||||||
|
assert "Returns:" in info
|
||||||
|
|
||||||
|
def test_complex_function_info(self):
|
||||||
|
"""Test info extraction for function with complex signature."""
|
||||||
|
info = get_function_info(complex_func)
|
||||||
|
assert "complex_func" in info
|
||||||
|
assert "pos1" in info
|
||||||
|
assert "pos2" in info
|
||||||
|
assert "*args" in info
|
||||||
|
assert "**kwargs" in info
|
||||||
|
assert "kw1='default'" in info
|
||||||
|
assert "A function with complex signature" in info
|
||||||
|
|
||||||
|
def test_no_docstring_function(self):
|
||||||
|
"""Test handling of functions without docstrings."""
|
||||||
|
info = get_function_info(no_docstring_func)
|
||||||
|
assert "no_docstring_func" in info
|
||||||
|
assert "No docstring provided" in info
|
||||||
|
|
||||||
|
|
||||||
Loading…
Reference in New Issue