Changes (#37)
* add way to extract tool use * respect aiderignore for listing files * properly use filetypes in ripgrep
This commit is contained in:
parent
0449564109
commit
bfc0e9c626
|
|
@ -1,6 +1,7 @@
|
||||||
import inspect
|
import inspect
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Dict, Any, Generator, List, Optional, Union
|
from typing import Dict, Any, Generator, List, Optional, Union
|
||||||
|
import re
|
||||||
|
|
||||||
from langchain_core.messages import AIMessage, HumanMessage, BaseMessage, SystemMessage
|
from langchain_core.messages import AIMessage, HumanMessage, BaseMessage, SystemMessage
|
||||||
from ra_aid.exceptions import ToolExecutionError
|
from ra_aid.exceptions import ToolExecutionError
|
||||||
|
|
@ -143,8 +144,17 @@ Output **ONLY THE CODE** and **NO MARKDOWN BACKTICKS**"""
|
||||||
tool.func.__name__: tool.func
|
tool.func.__name__: tool.func
|
||||||
for tool in self.tools
|
for tool in self.tools
|
||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
||||||
|
code = code.strip()
|
||||||
|
code = code.replace("\n", " ")
|
||||||
|
|
||||||
|
# if the eval fails, try to extract it via a model call
|
||||||
|
if _does_not_conform_to_pattern(code):
|
||||||
|
functions_list = "\n\n".join(self.available_functions)
|
||||||
|
code = _extract_tool_call(code, functions_list)
|
||||||
|
|
||||||
result = eval(code.strip(), globals_dict)
|
result = eval(code.strip(), globals_dict)
|
||||||
return result
|
return result
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -250,3 +260,40 @@ Output **ONLY THE CODE** and **NO MARKDOWN BACKTICKS**"""
|
||||||
except ToolExecutionError as e:
|
except ToolExecutionError as e:
|
||||||
chat_history.append(HumanMessage(content=f"Your tool call caused an error: {e}\n\nPlease correct your tool call and try again."))
|
chat_history.append(HumanMessage(content=f"Your tool call caused an error: {e}\n\nPlease correct your tool call and try again."))
|
||||||
yield self._create_error_chunk(str(e))
|
yield self._create_error_chunk(str(e))
|
||||||
|
|
||||||
|
def _extract_tool_call(code: str, functions_list: str) -> str:
|
||||||
|
from ra_aid.tools.expert import get_model
|
||||||
|
|
||||||
|
model = get_model()
|
||||||
|
prompt = f"""
|
||||||
|
I'm conversing with a AI model and requiring responses in a particular format: A function call with any parameters escaped. Here is an example:
|
||||||
|
|
||||||
|
```
|
||||||
|
run_programming_task("blah \" blah\" blah")
|
||||||
|
```
|
||||||
|
|
||||||
|
The following tasks are allowed:
|
||||||
|
|
||||||
|
{functions_list}
|
||||||
|
|
||||||
|
I got this invalid response from the model, can you format it so it becomes a correct function call?
|
||||||
|
|
||||||
|
```
|
||||||
|
{code}
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
response = model.invoke(prompt)
|
||||||
|
response = response.content
|
||||||
|
|
||||||
|
pattern = r"([\w_\-]+)\((.*?)\)"
|
||||||
|
matches = re.findall(pattern, response, re.DOTALL)
|
||||||
|
if len(matches) == 0:
|
||||||
|
raise ToolExecutionError("Failed to extract tool call")
|
||||||
|
ma = matches[0][0].strip()
|
||||||
|
mb = matches[0][1].strip().replace("\n", " ")
|
||||||
|
return f"{ma}({mb})"
|
||||||
|
|
||||||
|
|
||||||
|
def _does_not_conform_to_pattern(s):
|
||||||
|
pattern = r"^\s*[\w_\-]+\((.*?)\)\s*$"
|
||||||
|
return not re.match(pattern, s, re.DOTALL)
|
||||||
|
|
@ -65,13 +65,35 @@ def load_gitignore_patterns(path: Path) -> pathspec.PathSpec:
|
||||||
"""
|
"""
|
||||||
gitignore_path = path / '.gitignore'
|
gitignore_path = path / '.gitignore'
|
||||||
patterns = []
|
patterns = []
|
||||||
|
|
||||||
|
def modify_path(p: str) -> str:
|
||||||
|
# Python pathspec doesn't treat `blah/` as a ignore folder, but `blah`. So we strip them
|
||||||
|
p = p.strip()
|
||||||
|
if p.endswith("/"):
|
||||||
|
return p[:-1]
|
||||||
|
return p
|
||||||
|
|
||||||
# Load patterns from .gitignore if it exists
|
# Load patterns from .gitignore if it exists
|
||||||
if gitignore_path.exists():
|
if gitignore_path.exists():
|
||||||
with open(gitignore_path) as f:
|
with open(gitignore_path) as f:
|
||||||
patterns.extend(line.strip() for line in f
|
patterns.extend(
|
||||||
if line.strip() and not line.startswith('#'))
|
modify_path(line)
|
||||||
|
for line in f
|
||||||
|
if line.strip() and not line.startswith("#")
|
||||||
|
)
|
||||||
|
|
||||||
|
# add patterns from .aiderignore if it exists
|
||||||
|
aiderignore_path = path / ".aiderignore"
|
||||||
|
|
||||||
|
# Load patterns from .gitignore if it exists
|
||||||
|
if aiderignore_path.exists():
|
||||||
|
with open(aiderignore_path) as f:
|
||||||
|
patterns.extend(
|
||||||
|
modify_path(line)
|
||||||
|
for line in f
|
||||||
|
if line.strip() and not line.startswith("#")
|
||||||
|
)
|
||||||
|
|
||||||
# Add default patterns
|
# Add default patterns
|
||||||
patterns.extend(DEFAULT_EXCLUDE_PATTERNS)
|
patterns.extend(DEFAULT_EXCLUDE_PATTERNS)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -24,6 +24,46 @@ DEFAULT_EXCLUDE_DIRS = [
|
||||||
'.vscode'
|
'.vscode'
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
FILE_TYPE_MAP = {
|
||||||
|
# General programming languages
|
||||||
|
"py": "python",
|
||||||
|
"rs": "rust",
|
||||||
|
"js": "javascript",
|
||||||
|
"ts": "typescript",
|
||||||
|
"java": "java",
|
||||||
|
"c": "c",
|
||||||
|
"cpp": "cpp",
|
||||||
|
"h": "c-header",
|
||||||
|
"hpp": "cpp-header",
|
||||||
|
"cs": "csharp",
|
||||||
|
"go": "go",
|
||||||
|
"rb": "ruby",
|
||||||
|
"php": "php",
|
||||||
|
"swift": "swift",
|
||||||
|
"kt": "kotlin",
|
||||||
|
"sh": "sh",
|
||||||
|
"bash": "sh",
|
||||||
|
"r": "r",
|
||||||
|
"pl": "perl",
|
||||||
|
"scala": "scala",
|
||||||
|
"dart": "dart",
|
||||||
|
# Markup, data, and web
|
||||||
|
"html": "html",
|
||||||
|
"htm": "html",
|
||||||
|
"xml": "xml",
|
||||||
|
"css": "css",
|
||||||
|
"scss": "scss",
|
||||||
|
"json": "json",
|
||||||
|
"yaml": "yaml",
|
||||||
|
"yml": "yaml",
|
||||||
|
"toml": "toml",
|
||||||
|
"md": "markdown",
|
||||||
|
"markdown": "markdown",
|
||||||
|
"sql": "sql",
|
||||||
|
"psql": "postgres",
|
||||||
|
}
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
def ripgrep_search(
|
def ripgrep_search(
|
||||||
pattern: str,
|
pattern: str,
|
||||||
|
|
@ -63,7 +103,9 @@ def ripgrep_search(
|
||||||
cmd.append('--follow')
|
cmd.append('--follow')
|
||||||
|
|
||||||
if file_type:
|
if file_type:
|
||||||
cmd.extend(['-t', file_type])
|
if FILE_TYPE_MAP.get(file_type):
|
||||||
|
file_type = FILE_TYPE_MAP.get(file_type)
|
||||||
|
cmd.extend(["-t", file_type])
|
||||||
|
|
||||||
# Add exclusions
|
# Add exclusions
|
||||||
exclusions = DEFAULT_EXCLUDE_DIRS + (exclude_dirs or [])
|
exclusions = DEFAULT_EXCLUDE_DIRS + (exclude_dirs or [])
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue