From bfc0e9c626d5d7b6197404b797f8209804d0d5f0 Mon Sep 17 00:00:00 2001 From: Benedikt Terhechte Date: Thu, 9 Jan 2025 15:46:32 +0100 Subject: [PATCH] Changes (#37) * add way to extract tool use * respect aiderignore for listing files * properly use filetypes in ripgrep --- ra_aid/agents/ciayn_agent.py | 49 +++++++++++++++++++++++++++++++++- ra_aid/tools/list_directory.py | 30 ++++++++++++++++++--- ra_aid/tools/ripgrep.py | 44 +++++++++++++++++++++++++++++- 3 files changed, 117 insertions(+), 6 deletions(-) diff --git a/ra_aid/agents/ciayn_agent.py b/ra_aid/agents/ciayn_agent.py index 70e5211..9eb6bb3 100644 --- a/ra_aid/agents/ciayn_agent.py +++ b/ra_aid/agents/ciayn_agent.py @@ -1,6 +1,7 @@ import inspect from dataclasses import dataclass from typing import Dict, Any, Generator, List, Optional, Union +import re from langchain_core.messages import AIMessage, HumanMessage, BaseMessage, SystemMessage from ra_aid.exceptions import ToolExecutionError @@ -143,8 +144,17 @@ Output **ONLY THE CODE** and **NO MARKDOWN BACKTICKS**""" tool.func.__name__: tool.func for tool in self.tools } - + try: + + code = code.strip() + code = code.replace("\n", " ") + + # if the eval fails, try to extract it via a model call + if _does_not_conform_to_pattern(code): + functions_list = "\n\n".join(self.available_functions) + code = _extract_tool_call(code, functions_list) + result = eval(code.strip(), globals_dict) return result except Exception as e: @@ -250,3 +260,40 @@ Output **ONLY THE CODE** and **NO MARKDOWN BACKTICKS**""" except ToolExecutionError as e: chat_history.append(HumanMessage(content=f"Your tool call caused an error: {e}\n\nPlease correct your tool call and try again.")) yield self._create_error_chunk(str(e)) + +def _extract_tool_call(code: str, functions_list: str) -> str: + from ra_aid.tools.expert import get_model + + model = get_model() + prompt = f""" +I'm conversing with a AI model and requiring responses in a particular format: A function call with any parameters escaped. Here is an example: + +``` +run_programming_task("blah \" blah\" blah") +``` + +The following tasks are allowed: + +{functions_list} + +I got this invalid response from the model, can you format it so it becomes a correct function call? + +``` +{code} +``` + """ + response = model.invoke(prompt) + response = response.content + + pattern = r"([\w_\-]+)\((.*?)\)" + matches = re.findall(pattern, response, re.DOTALL) + if len(matches) == 0: + raise ToolExecutionError("Failed to extract tool call") + ma = matches[0][0].strip() + mb = matches[0][1].strip().replace("\n", " ") + return f"{ma}({mb})" + + +def _does_not_conform_to_pattern(s): + pattern = r"^\s*[\w_\-]+\((.*?)\)\s*$" + return not re.match(pattern, s, re.DOTALL) \ No newline at end of file diff --git a/ra_aid/tools/list_directory.py b/ra_aid/tools/list_directory.py index 52a0553..118d7de 100644 --- a/ra_aid/tools/list_directory.py +++ b/ra_aid/tools/list_directory.py @@ -65,13 +65,35 @@ def load_gitignore_patterns(path: Path) -> pathspec.PathSpec: """ gitignore_path = path / '.gitignore' patterns = [] - + + def modify_path(p: str) -> str: + # Python pathspec doesn't treat `blah/` as a ignore folder, but `blah`. So we strip them + p = p.strip() + if p.endswith("/"): + return p[:-1] + return p + # Load patterns from .gitignore if it exists if gitignore_path.exists(): with open(gitignore_path) as f: - patterns.extend(line.strip() for line in f - if line.strip() and not line.startswith('#')) - + patterns.extend( + modify_path(line) + for line in f + if line.strip() and not line.startswith("#") + ) + + # add patterns from .aiderignore if it exists + aiderignore_path = path / ".aiderignore" + + # Load patterns from .gitignore if it exists + if aiderignore_path.exists(): + with open(aiderignore_path) as f: + patterns.extend( + modify_path(line) + for line in f + if line.strip() and not line.startswith("#") + ) + # Add default patterns patterns.extend(DEFAULT_EXCLUDE_PATTERNS) diff --git a/ra_aid/tools/ripgrep.py b/ra_aid/tools/ripgrep.py index dfa534f..cfa293f 100644 --- a/ra_aid/tools/ripgrep.py +++ b/ra_aid/tools/ripgrep.py @@ -24,6 +24,46 @@ DEFAULT_EXCLUDE_DIRS = [ '.vscode' ] + +FILE_TYPE_MAP = { + # General programming languages + "py": "python", + "rs": "rust", + "js": "javascript", + "ts": "typescript", + "java": "java", + "c": "c", + "cpp": "cpp", + "h": "c-header", + "hpp": "cpp-header", + "cs": "csharp", + "go": "go", + "rb": "ruby", + "php": "php", + "swift": "swift", + "kt": "kotlin", + "sh": "sh", + "bash": "sh", + "r": "r", + "pl": "perl", + "scala": "scala", + "dart": "dart", + # Markup, data, and web + "html": "html", + "htm": "html", + "xml": "xml", + "css": "css", + "scss": "scss", + "json": "json", + "yaml": "yaml", + "yml": "yaml", + "toml": "toml", + "md": "markdown", + "markdown": "markdown", + "sql": "sql", + "psql": "postgres", +} + @tool def ripgrep_search( pattern: str, @@ -63,7 +103,9 @@ def ripgrep_search( cmd.append('--follow') if file_type: - cmd.extend(['-t', file_type]) + if FILE_TYPE_MAP.get(file_type): + file_type = FILE_TYPE_MAP.get(file_type) + cmd.extend(["-t", file_type]) # Add exclusions exclusions = DEFAULT_EXCLUDE_DIRS + (exclude_dirs or [])