* add way to extract tool use

* respect aiderignore for listing files

* properly use filetypes in ripgrep
This commit is contained in:
Benedikt Terhechte 2025-01-09 15:46:32 +01:00 committed by GitHub
parent 0449564109
commit bfc0e9c626
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 117 additions and 6 deletions

View File

@ -1,6 +1,7 @@
import inspect import inspect
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, Any, Generator, List, Optional, Union from typing import Dict, Any, Generator, List, Optional, Union
import re
from langchain_core.messages import AIMessage, HumanMessage, BaseMessage, SystemMessage from langchain_core.messages import AIMessage, HumanMessage, BaseMessage, SystemMessage
from ra_aid.exceptions import ToolExecutionError from ra_aid.exceptions import ToolExecutionError
@ -143,8 +144,17 @@ Output **ONLY THE CODE** and **NO MARKDOWN BACKTICKS**"""
tool.func.__name__: tool.func tool.func.__name__: tool.func
for tool in self.tools for tool in self.tools
} }
try: try:
code = code.strip()
code = code.replace("\n", " ")
# if the eval fails, try to extract it via a model call
if _does_not_conform_to_pattern(code):
functions_list = "\n\n".join(self.available_functions)
code = _extract_tool_call(code, functions_list)
result = eval(code.strip(), globals_dict) result = eval(code.strip(), globals_dict)
return result return result
except Exception as e: except Exception as e:
@ -250,3 +260,40 @@ Output **ONLY THE CODE** and **NO MARKDOWN BACKTICKS**"""
except ToolExecutionError as e: except ToolExecutionError as e:
chat_history.append(HumanMessage(content=f"Your tool call caused an error: {e}\n\nPlease correct your tool call and try again.")) chat_history.append(HumanMessage(content=f"Your tool call caused an error: {e}\n\nPlease correct your tool call and try again."))
yield self._create_error_chunk(str(e)) yield self._create_error_chunk(str(e))
def _extract_tool_call(code: str, functions_list: str) -> str:
from ra_aid.tools.expert import get_model
model = get_model()
prompt = f"""
I'm conversing with a AI model and requiring responses in a particular format: A function call with any parameters escaped. Here is an example:
```
run_programming_task("blah \" blah\" blah")
```
The following tasks are allowed:
{functions_list}
I got this invalid response from the model, can you format it so it becomes a correct function call?
```
{code}
```
"""
response = model.invoke(prompt)
response = response.content
pattern = r"([\w_\-]+)\((.*?)\)"
matches = re.findall(pattern, response, re.DOTALL)
if len(matches) == 0:
raise ToolExecutionError("Failed to extract tool call")
ma = matches[0][0].strip()
mb = matches[0][1].strip().replace("\n", " ")
return f"{ma}({mb})"
def _does_not_conform_to_pattern(s):
pattern = r"^\s*[\w_\-]+\((.*?)\)\s*$"
return not re.match(pattern, s, re.DOTALL)

View File

@ -65,13 +65,35 @@ def load_gitignore_patterns(path: Path) -> pathspec.PathSpec:
""" """
gitignore_path = path / '.gitignore' gitignore_path = path / '.gitignore'
patterns = [] patterns = []
def modify_path(p: str) -> str:
# Python pathspec doesn't treat `blah/` as a ignore folder, but `blah`. So we strip them
p = p.strip()
if p.endswith("/"):
return p[:-1]
return p
# Load patterns from .gitignore if it exists # Load patterns from .gitignore if it exists
if gitignore_path.exists(): if gitignore_path.exists():
with open(gitignore_path) as f: with open(gitignore_path) as f:
patterns.extend(line.strip() for line in f patterns.extend(
if line.strip() and not line.startswith('#')) modify_path(line)
for line in f
if line.strip() and not line.startswith("#")
)
# add patterns from .aiderignore if it exists
aiderignore_path = path / ".aiderignore"
# Load patterns from .gitignore if it exists
if aiderignore_path.exists():
with open(aiderignore_path) as f:
patterns.extend(
modify_path(line)
for line in f
if line.strip() and not line.startswith("#")
)
# Add default patterns # Add default patterns
patterns.extend(DEFAULT_EXCLUDE_PATTERNS) patterns.extend(DEFAULT_EXCLUDE_PATTERNS)

View File

@ -24,6 +24,46 @@ DEFAULT_EXCLUDE_DIRS = [
'.vscode' '.vscode'
] ]
FILE_TYPE_MAP = {
# General programming languages
"py": "python",
"rs": "rust",
"js": "javascript",
"ts": "typescript",
"java": "java",
"c": "c",
"cpp": "cpp",
"h": "c-header",
"hpp": "cpp-header",
"cs": "csharp",
"go": "go",
"rb": "ruby",
"php": "php",
"swift": "swift",
"kt": "kotlin",
"sh": "sh",
"bash": "sh",
"r": "r",
"pl": "perl",
"scala": "scala",
"dart": "dart",
# Markup, data, and web
"html": "html",
"htm": "html",
"xml": "xml",
"css": "css",
"scss": "scss",
"json": "json",
"yaml": "yaml",
"yml": "yaml",
"toml": "toml",
"md": "markdown",
"markdown": "markdown",
"sql": "sql",
"psql": "postgres",
}
@tool @tool
def ripgrep_search( def ripgrep_search(
pattern: str, pattern: str,
@ -63,7 +103,9 @@ def ripgrep_search(
cmd.append('--follow') cmd.append('--follow')
if file_type: if file_type:
cmd.extend(['-t', file_type]) if FILE_TYPE_MAP.get(file_type):
file_type = FILE_TYPE_MAP.get(file_type)
cmd.extend(["-t", file_type])
# Add exclusions # Add exclusions
exclusions = DEFAULT_EXCLUDE_DIRS + (exclude_dirs or []) exclusions = DEFAULT_EXCLUDE_DIRS + (exclude_dirs or [])