Changes (#37)
* add way to extract tool use * respect aiderignore for listing files * properly use filetypes in ripgrep
This commit is contained in:
parent
0449564109
commit
bfc0e9c626
|
|
@ -1,6 +1,7 @@
|
|||
import inspect
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Any, Generator, List, Optional, Union
|
||||
import re
|
||||
|
||||
from langchain_core.messages import AIMessage, HumanMessage, BaseMessage, SystemMessage
|
||||
from ra_aid.exceptions import ToolExecutionError
|
||||
|
|
@ -145,6 +146,15 @@ Output **ONLY THE CODE** and **NO MARKDOWN BACKTICKS**"""
|
|||
}
|
||||
|
||||
try:
|
||||
|
||||
code = code.strip()
|
||||
code = code.replace("\n", " ")
|
||||
|
||||
# if the eval fails, try to extract it via a model call
|
||||
if _does_not_conform_to_pattern(code):
|
||||
functions_list = "\n\n".join(self.available_functions)
|
||||
code = _extract_tool_call(code, functions_list)
|
||||
|
||||
result = eval(code.strip(), globals_dict)
|
||||
return result
|
||||
except Exception as e:
|
||||
|
|
@ -250,3 +260,40 @@ Output **ONLY THE CODE** and **NO MARKDOWN BACKTICKS**"""
|
|||
except ToolExecutionError as e:
|
||||
chat_history.append(HumanMessage(content=f"Your tool call caused an error: {e}\n\nPlease correct your tool call and try again."))
|
||||
yield self._create_error_chunk(str(e))
|
||||
|
||||
def _extract_tool_call(code: str, functions_list: str) -> str:
|
||||
from ra_aid.tools.expert import get_model
|
||||
|
||||
model = get_model()
|
||||
prompt = f"""
|
||||
I'm conversing with a AI model and requiring responses in a particular format: A function call with any parameters escaped. Here is an example:
|
||||
|
||||
```
|
||||
run_programming_task("blah \" blah\" blah")
|
||||
```
|
||||
|
||||
The following tasks are allowed:
|
||||
|
||||
{functions_list}
|
||||
|
||||
I got this invalid response from the model, can you format it so it becomes a correct function call?
|
||||
|
||||
```
|
||||
{code}
|
||||
```
|
||||
"""
|
||||
response = model.invoke(prompt)
|
||||
response = response.content
|
||||
|
||||
pattern = r"([\w_\-]+)\((.*?)\)"
|
||||
matches = re.findall(pattern, response, re.DOTALL)
|
||||
if len(matches) == 0:
|
||||
raise ToolExecutionError("Failed to extract tool call")
|
||||
ma = matches[0][0].strip()
|
||||
mb = matches[0][1].strip().replace("\n", " ")
|
||||
return f"{ma}({mb})"
|
||||
|
||||
|
||||
def _does_not_conform_to_pattern(s):
|
||||
pattern = r"^\s*[\w_\-]+\((.*?)\)\s*$"
|
||||
return not re.match(pattern, s, re.DOTALL)
|
||||
|
|
@ -66,11 +66,33 @@ def load_gitignore_patterns(path: Path) -> pathspec.PathSpec:
|
|||
gitignore_path = path / '.gitignore'
|
||||
patterns = []
|
||||
|
||||
def modify_path(p: str) -> str:
|
||||
# Python pathspec doesn't treat `blah/` as a ignore folder, but `blah`. So we strip them
|
||||
p = p.strip()
|
||||
if p.endswith("/"):
|
||||
return p[:-1]
|
||||
return p
|
||||
|
||||
# Load patterns from .gitignore if it exists
|
||||
if gitignore_path.exists():
|
||||
with open(gitignore_path) as f:
|
||||
patterns.extend(line.strip() for line in f
|
||||
if line.strip() and not line.startswith('#'))
|
||||
patterns.extend(
|
||||
modify_path(line)
|
||||
for line in f
|
||||
if line.strip() and not line.startswith("#")
|
||||
)
|
||||
|
||||
# add patterns from .aiderignore if it exists
|
||||
aiderignore_path = path / ".aiderignore"
|
||||
|
||||
# Load patterns from .gitignore if it exists
|
||||
if aiderignore_path.exists():
|
||||
with open(aiderignore_path) as f:
|
||||
patterns.extend(
|
||||
modify_path(line)
|
||||
for line in f
|
||||
if line.strip() and not line.startswith("#")
|
||||
)
|
||||
|
||||
# Add default patterns
|
||||
patterns.extend(DEFAULT_EXCLUDE_PATTERNS)
|
||||
|
|
|
|||
|
|
@ -24,6 +24,46 @@ DEFAULT_EXCLUDE_DIRS = [
|
|||
'.vscode'
|
||||
]
|
||||
|
||||
|
||||
FILE_TYPE_MAP = {
|
||||
# General programming languages
|
||||
"py": "python",
|
||||
"rs": "rust",
|
||||
"js": "javascript",
|
||||
"ts": "typescript",
|
||||
"java": "java",
|
||||
"c": "c",
|
||||
"cpp": "cpp",
|
||||
"h": "c-header",
|
||||
"hpp": "cpp-header",
|
||||
"cs": "csharp",
|
||||
"go": "go",
|
||||
"rb": "ruby",
|
||||
"php": "php",
|
||||
"swift": "swift",
|
||||
"kt": "kotlin",
|
||||
"sh": "sh",
|
||||
"bash": "sh",
|
||||
"r": "r",
|
||||
"pl": "perl",
|
||||
"scala": "scala",
|
||||
"dart": "dart",
|
||||
# Markup, data, and web
|
||||
"html": "html",
|
||||
"htm": "html",
|
||||
"xml": "xml",
|
||||
"css": "css",
|
||||
"scss": "scss",
|
||||
"json": "json",
|
||||
"yaml": "yaml",
|
||||
"yml": "yaml",
|
||||
"toml": "toml",
|
||||
"md": "markdown",
|
||||
"markdown": "markdown",
|
||||
"sql": "sql",
|
||||
"psql": "postgres",
|
||||
}
|
||||
|
||||
@tool
|
||||
def ripgrep_search(
|
||||
pattern: str,
|
||||
|
|
@ -63,7 +103,9 @@ def ripgrep_search(
|
|||
cmd.append('--follow')
|
||||
|
||||
if file_type:
|
||||
cmd.extend(['-t', file_type])
|
||||
if FILE_TYPE_MAP.get(file_type):
|
||||
file_type = FILE_TYPE_MAP.get(file_type)
|
||||
cmd.extend(["-t", file_type])
|
||||
|
||||
# Add exclusions
|
||||
exclusions = DEFAULT_EXCLUDE_DIRS + (exclude_dirs or [])
|
||||
|
|
|
|||
Loading…
Reference in New Issue