* add way to extract tool use

* respect aiderignore for listing files

* properly use filetypes in ripgrep
This commit is contained in:
Benedikt Terhechte 2025-01-09 15:46:32 +01:00 committed by GitHub
parent 0449564109
commit bfc0e9c626
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 117 additions and 6 deletions

View File

@ -1,6 +1,7 @@
import inspect
from dataclasses import dataclass
from typing import Dict, Any, Generator, List, Optional, Union
import re
from langchain_core.messages import AIMessage, HumanMessage, BaseMessage, SystemMessage
from ra_aid.exceptions import ToolExecutionError
@ -145,6 +146,15 @@ Output **ONLY THE CODE** and **NO MARKDOWN BACKTICKS**"""
}
try:
code = code.strip()
code = code.replace("\n", " ")
# if the eval fails, try to extract it via a model call
if _does_not_conform_to_pattern(code):
functions_list = "\n\n".join(self.available_functions)
code = _extract_tool_call(code, functions_list)
result = eval(code.strip(), globals_dict)
return result
except Exception as e:
@ -250,3 +260,40 @@ Output **ONLY THE CODE** and **NO MARKDOWN BACKTICKS**"""
except ToolExecutionError as e:
chat_history.append(HumanMessage(content=f"Your tool call caused an error: {e}\n\nPlease correct your tool call and try again."))
yield self._create_error_chunk(str(e))
def _extract_tool_call(code: str, functions_list: str) -> str:
    """Ask a model to rewrite a malformed reply as a single function call.

    The model returned by ``ra_aid.tools.expert.get_model()`` is prompted with
    the allowed functions and the invalid reply. The first ``name(...)`` call
    found in its answer is returned as ``name(args)``, with newlines inside
    the arguments replaced by spaces.

    The caller passes the returned string to ``eval``, so the result must be
    treated as model-generated (untrusted) code.

    Args:
        code: The reply that did not look like a bare function call.
        functions_list: Text describing the functions the model may call.

    Returns:
        The extracted call as a single-line ``name(args)`` string.

    Raises:
        ToolExecutionError: If the answer contains no complete function call.
    """
    from ra_aid.tools.expert import get_model

    model = get_model()
    # The backslashes in the example are doubled on purpose: in a non-raw
    # f-string a single ``\"`` collapses to ``"``, which would show the model
    # an unescaped (and therefore invalid) example call.
    prompt = f"""
I'm conversing with a AI model and requiring responses in a particular format: A function call with any parameters escaped. Here is an example:
```
run_programming_task("blah \\" blah\\" blah")
```
The following tasks are allowed:
{functions_list}
I got this invalid response from the model, can you format it so it becomes a correct function call?
```
{code}
```
"""
    response = model.invoke(prompt)
    response = response.content

    # A lazy ``\((.*?)\)`` match would stop at the first ")" and truncate any
    # argument containing a parenthesis, so the closing parenthesis is found
    # by a quote-aware balanced scan instead.
    for opening in re.finditer(r"[\w_\-]+\(", response):
        closing = _find_closing_paren(response, opening.end())
        if closing is None:
            continue
        name = response[opening.start():opening.end() - 1].strip()
        arguments = response[opening.end():closing].strip().replace("\n", " ")
        return f"{name}({arguments})"
    raise ToolExecutionError("Failed to extract tool call")


def _find_closing_paren(text: str, start: int) -> Optional[int]:
    """Return the index of the ")" closing a "(" that ends just before ``start``.

    Parentheses inside single- or double-quoted strings are ignored, and a
    backslash inside a string escapes the next character. Returns ``None``
    when the text ends before the parenthesis is balanced.
    """
    depth = 1
    quote = None
    escaped = False
    for index in range(start, len(text)):
        char = text[index]
        if quote is not None:
            if escaped:
                escaped = False
            elif char == "\\":
                escaped = True
            elif char == quote:
                quote = None
        elif char in "\"'":
            quote = char
        elif char == "(":
            depth += 1
        elif char == ")":
            depth -= 1
            if depth == 0:
                return index
    return None
def _does_not_conform_to_pattern(s):
pattern = r"^\s*[\w_\-]+\((.*?)\)\s*$"
return not re.match(pattern, s, re.DOTALL)

View File

@ -66,11 +66,33 @@ def load_gitignore_patterns(path: Path) -> pathspec.PathSpec:
gitignore_path = path / '.gitignore'
patterns = []
def modify_path(p: str) -> str:
    """Normalise one ignore-file line into a pattern.

    Surrounding whitespace is stripped and a single trailing "/" is removed.
    The original author's note says pathspec treats ``blah`` but not
    ``blah/`` as an ignored folder (not verified here; confirm against the
    pathspec version in use).
    """
    trimmed = p.strip()
    return trimmed[:-1] if trimmed.endswith("/") else trimmed
# Load patterns from .gitignore if it exists
if gitignore_path.exists():
with open(gitignore_path) as f:
patterns.extend(line.strip() for line in f
if line.strip() and not line.startswith('#'))
patterns.extend(
modify_path(line)
for line in f
if line.strip() and not line.startswith("#")
)
# add patterns from .aiderignore if it exists
aiderignore_path = path / ".aiderignore"
# Load patterns from .aiderignore if it exists
if aiderignore_path.exists():
with open(aiderignore_path) as f:
patterns.extend(
modify_path(line)
for line in f
if line.strip() and not line.startswith("#")
)
# Add default patterns
patterns.extend(DEFAULT_EXCLUDE_PATTERNS)

View File

@ -24,6 +24,46 @@ DEFAULT_EXCLUDE_DIRS = [
'.vscode'
]
# Maps a file extension to the type name passed to ripgrep's ``-t`` flag.
# The table is written as (type name, extensions) pairs and flattened into an
# extension -> type name dict; extensions sharing a type are listed together.
# NOTE(review): the type names have not been checked against
# ``rg --type-list``; confirm each one is a type ripgrep actually defines.
FILE_TYPE_MAP = {
    extension: type_name
    for type_name, extensions in (
        # General programming languages
        ("python", ("py",)),
        ("rust", ("rs",)),
        ("javascript", ("js",)),
        ("typescript", ("ts",)),
        ("java", ("java",)),
        ("c", ("c",)),
        ("cpp", ("cpp",)),
        ("c-header", ("h",)),
        ("cpp-header", ("hpp",)),
        ("csharp", ("cs",)),
        ("go", ("go",)),
        ("ruby", ("rb",)),
        ("php", ("php",)),
        ("swift", ("swift",)),
        ("kotlin", ("kt",)),
        ("sh", ("sh", "bash")),
        ("r", ("r",)),
        ("perl", ("pl",)),
        ("scala", ("scala",)),
        ("dart", ("dart",)),
        # Markup, data, and web
        ("html", ("html", "htm")),
        ("xml", ("xml",)),
        ("css", ("css",)),
        ("scss", ("scss",)),
        ("json", ("json",)),
        ("yaml", ("yaml", "yml")),
        ("toml", ("toml",)),
        ("markdown", ("md", "markdown")),
        ("sql", ("sql",)),
        ("postgres", ("psql",)),
    )
    for extension in extensions
}
@tool
def ripgrep_search(
pattern: str,
@ -63,7 +103,9 @@ def ripgrep_search(
cmd.append('--follow')
if file_type:
cmd.extend(['-t', file_type])
if FILE_TYPE_MAP.get(file_type):
file_type = FILE_TYPE_MAP.get(file_type)
cmd.extend(["-t", file_type])
# Add exclusions
exclusions = DEFAULT_EXCLUDE_DIRS + (exclude_dirs or [])