RA.Aid/scripts/generate_swebench_dataset.py

394 lines
13 KiB
Python
Executable File

"""
Script to generate predictions for SWE-bench Lite (princeton-nlp/SWE-bench_Lite).
This version uses 'uv venv' and 'uv pip' / 'uv run ra-aid' commands to manage everything in the environment.
It:
- Loads the SWE-bench Lite dataset
- For each instance, clones (or reuses) the repo at the specified commit
- Creates or reuses a dedicated Python virtual environment via `uv venv`
- Installs `ra-aid` in editable mode + any project dependencies via `uv pip`
- Also installs the cloned project itself in editable mode if it appears to be a Python package
- Calls `uv run ra-aid` to generate a patch
- Writes out predictions in JSON format
No progress bar or spinner is used, allowing `ra-aid` output to stream directly.
"""
import argparse
import json
import logging
import shutil
import subprocess
import sys
from datetime import datetime
from pathlib import Path
from typing import Optional, Tuple, Dict, Any, List
from git import Repo
from rich.logging import RichHandler
# If you'd like to override Python versions for specific repos:
# Maps "org/repo" -> Python version string handed to `uv venv --python=...`.
# Repos not listed here fall back to the "3.12" default applied in uv_venv().
PYTHON_VERSION_OVERRIDES = {
# "someorg/somerepo": "3.9",
}
def setup_logging(log_dir: Path, verbose: bool = False) -> None:
    """Attach a file handler and a Rich console handler to the root logger.

    The file handler always records at DEBUG level into
    ``log_dir/generate_dataset.log``; the console handler (and the root
    logger itself) record DEBUG only when ``verbose`` is True, else INFO.
    """
    log_dir.mkdir(parents=True, exist_ok=True)

    console_level = logging.DEBUG if verbose else logging.INFO
    root = logging.getLogger()
    root.setLevel(console_level)

    # File handler: everything, with timestamps, for post-mortem debugging.
    to_file = logging.FileHandler(log_dir / "generate_dataset.log")
    to_file.setLevel(logging.DEBUG)
    to_file.setFormatter(
        logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    )
    root.addHandler(to_file)

    # Console handler: Rich output without time/path noise.
    to_console = RichHandler(rich_tracebacks=True, show_time=False, show_path=False)
    to_console.setLevel(console_level)
    root.addHandler(to_console)
def load_dataset_safely() -> Optional[Any]:
    """Load the SWE-bench Lite dataset, returning None (and logging) on any failure."""
    try:
        from datasets import load_dataset
    except Exception as e:
        logging.error(f"Failed to load dataset: {e}")
        return None
    try:
        return load_dataset("princeton-nlp/SWE-bench_Lite")
    except Exception as e:
        logging.error(f"Failed to load dataset: {e}")
        return None
def create_output_dirs() -> Tuple[Path, Path]:
    """Create and return (base_dir, log_dir) under evaluation/default/<YYYYMMDD>_raaid."""
    today = datetime.now().strftime("%Y%m%d")
    base_dir = Path("evaluation", "default", f"{today}_raaid")
    log_dir = base_dir / "logs"
    for directory in (base_dir, log_dir):
        directory.mkdir(parents=True, exist_ok=True)
    return base_dir, log_dir
def uv_venv(repo_dir: Path, repo_name: str, force_venv: bool) -> None:
    """
    Create a .venv inside ``repo_dir`` via ``uv venv``.

    When ``force_venv`` is set, any existing .venv is deleted first.
    The interpreter version comes from PYTHON_VERSION_OVERRIDES for this
    repo, falling back to "3.12". Failures are logged, not raised.
    """
    venv_path = repo_dir / ".venv"
    if force_venv and venv_path.exists():
        logging.info(f"Removing existing .venv at {venv_path}")
        shutil.rmtree(venv_path)

    # Per-repo override, otherwise the 3.12 default (always truthy, so the
    # --python flag is always passed).
    version = PYTHON_VERSION_OVERRIDES.get(repo_name) or "3.12"
    command = ["uv", "venv", f"--python={version}", ".venv"]
    try:
        subprocess.run(command, cwd=repo_dir, check=True)
    except Exception as e:
        logging.error(f"Failed to create venv in {repo_dir}: {e}")
def uv_pip_install(repo_dir: Path, args: List[str]) -> None:
    """Run ``uv pip install <args>`` inside ``repo_dir``; log (don't raise) on failure."""
    command = ["uv", "pip", "install", *args]
    try:
        subprocess.run(command, cwd=repo_dir, check=True)
    except Exception as e:
        logging.error(f"Failed to run uv pip install {args}: {e}")
def uv_run_raaid(repo_dir: Path, prompt: str, timeout: Optional[float] = None) -> Optional[str]:
    """
    Run ``uv run ra-aid`` with the given prompt inside ``repo_dir``.

    Output is deliberately not captured, so it streams straight to the
    console.

    Args:
        repo_dir: Repository directory to run in (with its .venv set up).
        prompt: Task text passed to ra-aid via ``-m``.
        timeout: Optional max seconds to wait for ra-aid. Defaults to None
            (wait forever), matching prior behavior. Previously the
            TimeoutExpired handler below was unreachable because no timeout
            was ever passed to subprocess.run.

    Returns:
        The git patch text for the changes ra-aid made, or None on
        non-zero exit, timeout, any other error, or no changes.
    """
    cmd = [
        "uv", "run", "ra-aid",
        "--cowboy-mode",
        "-m", prompt
    ]
    try:
        result = subprocess.run(
            cmd,
            cwd=repo_dir,
            text=True,
            check=False,      # exit code handled manually below
            timeout=timeout,  # makes the TimeoutExpired handler reachable
        )
        if result.returncode != 0:
            logging.error("ra-aid returned non-zero exit code.")
            return None
    except subprocess.TimeoutExpired:
        logging.error("ra-aid timed out")
        return None
    except Exception as e:
        logging.error(f"ra-aid error: {e}")
        return None
    # Success: collect whatever the agent changed as a unified diff.
    return get_git_patch(repo_dir)
def get_git_patch(repo_dir: Path) -> Optional[str]:
    """Return the unified diff of uncommitted changes in ``repo_dir``, or None.

    None is returned when the working tree is clean, the diff is empty or
    whitespace-only, the diff contains no '+' lines, or GitPython raises.
    """
    try:
        repo = Repo(repo_dir)
        if not repo.is_dirty():
            logging.info("No changes detected in repository.")
            return None
        diff_text = repo.git.diff(unified=3)
        if not diff_text.strip():
            return None
        # Require at least one added line; a pure-deletion diff is discarded.
        has_additions = any(ln.startswith('+') for ln in diff_text.splitlines())
        return diff_text if has_additions else None
    except Exception as e:
        logging.error(f"Failed to generate patch: {e}")
        return None
def setup_venv_and_deps(repo_dir: Path, repo_name: str, force_venv: bool) -> None:
    """
    Provision the repo's virtualenv and install everything ra-aid needs:

    1. create (or reuse) .venv via ``uv venv``
    2. upgrade pip, then setuptools + wheel (so pkg_resources is available)
    3. install the local ra-aid checkout in editable mode
    4. install the project (pyproject) and any requirements files
    5. editable-install the project when it looks like a Python package
    """
    uv_venv(repo_dir, repo_name, force_venv)

    # Base tooling first so later installs behave.
    uv_pip_install(repo_dir, ["--upgrade", "pip"])
    uv_pip_install(repo_dir, ["--upgrade", "setuptools", "wheel"])

    # ra-aid lives one directory above scripts/.
    ra_aid_root = Path(__file__).resolve().parent.parent
    uv_pip_install(repo_dir, ["-e", str(ra_aid_root)])

    pyproject = repo_dir / "pyproject.toml"
    if pyproject.is_file():
        uv_pip_install(repo_dir, ["."])

    for req_name in ("requirements.txt", "requirements-dev.txt"):
        if (repo_dir / req_name).is_file():
            uv_pip_install(repo_dir, ["-r", req_name])

    # Editable install when the clone looks like an installable package.
    if pyproject.is_file() or (repo_dir / "setup.py").is_file():
        logging.info("Installing cloned project in editable mode.")
        uv_pip_install(repo_dir, ["-e", "."])
def build_prompt(problem_statement: str, fail_tests: List[str], pass_tests: List[str]) -> str:
    """Assemble the ra-aid task prompt from the problem statement and test lists."""
    parts = [problem_statement, "\n\nTests that need to be fixed:\n```\n"]
    parts.extend(f"- {t}\n" for t in fail_tests)
    parts.append("```\n\n")
    # The "must remain passing" section is only emitted when there is data for it.
    if pass_tests:
        parts.append("Tests that must remain passing:\n```\n")
        parts.extend(f"- {t}\n" for t in pass_tests)
        parts.append("```\n\n")
    parts.append(
        "\n\nYou must run all above tests both **before and after** making changes, "
        "and ensure they pass as you do your work. Do not write any new test cases."
    )
    return "".join(parts)
def process_instance(
    instance: Dict[str, Any],
    projects_dir: Path,
    reuse_repo: bool,
    force_venv: bool
) -> Dict[str, Any]:
    """
    Run ra-aid on one SWE-bench instance and return its prediction record.

    Steps: clone (or reuse/recreate) the repo under projects_dir/<instance_id>,
    check out the base commit, provision the venv + deps, build the prompt,
    run ra-aid (output streams to the console), and collect the patch.
    Any failure yields a record with an empty patch rather than raising.
    """
    inst_id = instance.get("instance_id", "<unknown>")
    repo_name = instance["repo"]
    commit = instance["base_commit"]
    problem_statement = instance["problem_statement"]

    def _as_list(value):
        # Dataset fields sometimes arrive as a single string instead of a list.
        return [value] if isinstance(value, str) else value

    fail_tests = _as_list(instance.get("FAIL_TO_PASS", []))
    pass_tests = _as_list(instance.get("PASS_TO_PASS", []))

    # Bare "org/name" entries are expanded into a full GitHub clone URL.
    repo_url = repo_name if "github.com" in repo_name else f"https://github.com/{repo_name}.git"
    checkout_dir = projects_dir / f"{inst_id}"

    prediction = {
        "instance_id": inst_id,
        "model_patch": "",
        "model_name_or_path": "ra-aid"
    }
    try:
        if checkout_dir.exists() and reuse_repo:
            logging.info(f"Reusing existing directory: {checkout_dir}")
            repo = Repo(checkout_dir)
        else:
            if checkout_dir.exists():
                logging.info(f"Deleting existing directory: {checkout_dir}")
                shutil.rmtree(checkout_dir)
            else:
                logging.info(f"Cloning {repo_url} -> {checkout_dir}")
            repo = Repo.clone_from(repo_url, checkout_dir)
        # Pin the working tree to the instance's base commit.
        repo.git.checkout(commit)
        # Environment, prompt, agent run.
        setup_venv_and_deps(checkout_dir, repo_name, force_venv)
        prompt_text = build_prompt(problem_statement, fail_tests, pass_tests)
        patch = uv_run_raaid(checkout_dir, prompt_text)
        if patch:
            prediction["model_patch"] = patch
    except Exception as e:
        logging.error(f"Failed to process {repo_url}:{commit} - {e}")
    return prediction
def main() -> None:
    """Parse CLI args, load SWE-bench Lite, and write predictions.json.

    Predictions are persisted after every instance so a crash mid-run does
    not lose earlier results; the final file contents are identical to a
    single write at the end.
    """
    parser = argparse.ArgumentParser(
        description="Generate predictions for SWE-bench Lite using uv + ra-aid (no progress bar)."
    )
    parser.add_argument(
        "output_dir",
        type=Path,
        help="Directory to store prediction file"
    )
    parser.add_argument(
        "--projects-dir",
        type=Path,
        required=True,
        help="Directory where projects will be cloned."
    )
    parser.add_argument(
        "--num-instances",
        type=int,
        default=None,
        help="Number of instances to process (default: all)"
    )
    parser.add_argument(
        "--reuse-repo",
        action="store_true",
        help="If set, do not delete an existing repo directory. We'll reuse it."
    )
    parser.add_argument(
        "--force-venv",
        action="store_true",
        help="If set, recreate the .venv even if it exists."
    )
    parser.add_argument(
        "--verbose",
        action="store_true",
        help="Enable verbose logging"
    )
    args = parser.parse_args()
    # Create base/log dirs and set up logging.
    _base_dir, log_dir = create_output_dirs()
    setup_logging(log_dir, args.verbose)
    logging.info("Starting script")
    # Ensure projects dir
    args.projects_dir.mkdir(parents=True, exist_ok=True)
    # Load dataset. load_dataset_safely() does its own `datasets` import with
    # error handling; the previous unconditional `from datasets import
    # load_dataset` here was unused and would have crashed before logging was
    # even configured.
    dataset = load_dataset_safely()
    if dataset is None:
        sys.exit(1)
    # Combine dev + test splits.
    all_data = list(dataset["dev"]) + list(dataset["test"])
    # Ensure output dir
    args.output_dir.mkdir(parents=True, exist_ok=True)
    predictions_file = args.output_dir / "predictions.json"
    predictions: List[Dict[str, str]] = []
    # None/0 => process everything.
    limit = args.num_instances if args.num_instances else len(all_data)
    logging.info(f"Processing up to {limit} instances.")
    for i, inst in enumerate(all_data[:limit]):
        logging.info(f"=== Instance {i+1}/{limit}, ID={inst.get('instance_id')} ===")
        predictions.append(
            process_instance(inst, args.projects_dir, args.reuse_repo, args.force_venv)
        )
        # Save incrementally so earlier work survives a crash.
        with open(predictions_file, "w", encoding="utf-8") as f:
            json.dump(predictions, f, indent=2)
    logging.info("Done generating predictions.")
if __name__ == "__main__":
    try:
        main()
    except KeyboardInterrupt:
        # Ctrl-C: exit quietly with a short notice instead of a traceback.
        print("\nOperation cancelled by user.")
        sys.exit(1)
    except Exception:
        # Anything else: record the full traceback via logging, then fail.
        logging.exception("Unhandled error occurred.")
        sys.exit(1)