SWEBench updates.

This commit is contained in:
AI Christianson 2024-12-31 08:36:19 -05:00
parent 9303736418
commit 4fca32a508
1 changed files with 54 additions and 25 deletions

View File

@ -10,6 +10,8 @@ It:
- Installs `ra-aid` in editable mode + any project dependencies via `uv pip` - Installs `ra-aid` in editable mode + any project dependencies via `uv pip`
- Calls `uv run ra-aid` to generate a patch - Calls `uv run ra-aid` to generate a patch
- Writes out predictions in JSON format - Writes out predictions in JSON format
No progress bar or spinner is used, allowing `ra-aid` output to stream directly.
""" """
import argparse import argparse
@ -24,13 +26,13 @@ from typing import Optional, Tuple, Dict, Any, List
from git import Repo from git import Repo
from rich.logging import RichHandler from rich.logging import RichHandler
from rich.progress import Progress, SpinnerColumn, TextColumn, TimeElapsedColumn
# If you'd like to override Python versions for specific repos: # If you'd like to override Python versions for specific repos:
PYTHON_VERSION_OVERRIDES = { PYTHON_VERSION_OVERRIDES = {
# "someorg/somerepo": "3.9", # "someorg/somerepo": "3.9",
} }
def setup_logging(log_dir: Path, verbose: bool = False) -> None: def setup_logging(log_dir: Path, verbose: bool = False) -> None:
"""Configure logging with both file and console handlers.""" """Configure logging with both file and console handlers."""
log_dir.mkdir(parents=True, exist_ok=True) log_dir.mkdir(parents=True, exist_ok=True)
@ -55,6 +57,7 @@ def setup_logging(log_dir: Path, verbose: bool = False) -> None:
console_handler.setLevel(logging.DEBUG if verbose else logging.INFO) console_handler.setLevel(logging.DEBUG if verbose else logging.INFO)
root_logger.addHandler(console_handler) root_logger.addHandler(console_handler)
def load_dataset_safely() -> Optional[Any]: def load_dataset_safely() -> Optional[Any]:
"""Load SWE-bench Lite dataset with error handling.""" """Load SWE-bench Lite dataset with error handling."""
try: try:
@ -65,6 +68,7 @@ def load_dataset_safely() -> Optional[Any]:
logging.error(f"Failed to load dataset: {e}") logging.error(f"Failed to load dataset: {e}")
return None return None
def create_output_dirs() -> Tuple[Path, Path]: def create_output_dirs() -> Tuple[Path, Path]:
"""Create base/log directory structure.""" """Create base/log directory structure."""
date_str = datetime.now().strftime("%Y%m%d") date_str = datetime.now().strftime("%Y%m%d")
@ -74,6 +78,7 @@ def create_output_dirs() -> Tuple[Path, Path]:
log_dir.mkdir(parents=True, exist_ok=True) log_dir.mkdir(parents=True, exist_ok=True)
return base_dir, log_dir return base_dir, log_dir
def uv_venv(repo_dir: Path, repo_name: str, force_venv: bool) -> None: def uv_venv(repo_dir: Path, repo_name: str, force_venv: bool) -> None:
""" """
Create (or reuse) a .venv in 'repo_dir' using 'uv venv'. Create (or reuse) a .venv in 'repo_dir' using 'uv venv'.
@ -87,7 +92,7 @@ def uv_venv(repo_dir: Path, repo_name: str, force_venv: bool) -> None:
logging.info(f"Removing existing .venv at {venv_dir}") logging.info(f"Removing existing .venv at {venv_dir}")
shutil.rmtree(venv_dir) shutil.rmtree(venv_dir)
python_version = PYTHON_VERSION_OVERRIDES.get(repo_name, None) python_version = PYTHON_VERSION_OVERRIDES.get(repo_name, None) or "3.12"
cmd = ["uv", "venv"] cmd = ["uv", "venv"]
if python_version: if python_version:
cmd.append(f"--python={python_version}") cmd.append(f"--python={python_version}")
@ -98,6 +103,7 @@ def uv_venv(repo_dir: Path, repo_name: str, force_venv: bool) -> None:
except Exception as e: except Exception as e:
logging.error(f"Failed to create venv in {repo_dir}: {e}") logging.error(f"Failed to create venv in {repo_dir}: {e}")
def uv_pip_install(repo_dir: Path, args: List[str]) -> None: def uv_pip_install(repo_dir: Path, args: List[str]) -> None:
""" """
Run 'uv pip install ...' in the specified repo_dir. Run 'uv pip install ...' in the specified repo_dir.
@ -109,9 +115,11 @@ def uv_pip_install(repo_dir: Path, args: List[str]) -> None:
except Exception as e: except Exception as e:
logging.error(f"Failed to run uv pip install {args}: {e}") logging.error(f"Failed to run uv pip install {args}: {e}")
def uv_run_raaid(repo_dir: Path, prompt: str) -> Optional[str]: def uv_run_raaid(repo_dir: Path, prompt: str) -> Optional[str]:
""" """
Call 'uv run ra-aid' with the given prompt in the environment. Call 'uv run ra-aid' with the given prompt in the environment,
streaming output directly to the console (capture_output=False).
Returns the patch if successful, else None. Returns the patch if successful, else None.
""" """
cmd = [ cmd = [
@ -119,12 +127,16 @@ def uv_run_raaid(repo_dir: Path, prompt: str) -> Optional[str]:
"--cowboy-mode", "--cowboy-mode",
"-m", prompt "-m", prompt
] ]
# We are NOT capturing output, so it streams live:
try: try:
result = subprocess.run(cmd, cwd=repo_dir, text=True, capture_output=True, timeout=300) result = subprocess.run(
cmd,
cwd=repo_dir,
text=True,
check=False, # We manually handle exit code
)
if result.returncode != 0: if result.returncode != 0:
logging.error("ra-aid returned non-zero exit code.") logging.error("ra-aid returned non-zero exit code.")
logging.debug(f"stdout: {result.stdout}")
logging.debug(f"stderr: {result.stderr}")
return None return None
except subprocess.TimeoutExpired: except subprocess.TimeoutExpired:
logging.error("ra-aid timed out") logging.error("ra-aid timed out")
@ -137,6 +149,7 @@ def uv_run_raaid(repo_dir: Path, prompt: str) -> Optional[str]:
patch = get_git_patch(repo_dir) patch = get_git_patch(repo_dir)
return patch return patch
def get_git_patch(repo_dir: Path) -> Optional[str]: def get_git_patch(repo_dir: Path) -> Optional[str]:
"""Generate a git patch from the current changes in `repo_dir`.""" """Generate a git patch from the current changes in `repo_dir`."""
try: try:
@ -154,6 +167,7 @@ def get_git_patch(repo_dir: Path) -> Optional[str]:
logging.error(f"Failed to generate patch: {e}") logging.error(f"Failed to generate patch: {e}")
return None return None
def setup_venv_and_deps(repo_dir: Path, repo_name: str, force_venv: bool) -> None: def setup_venv_and_deps(repo_dir: Path, repo_name: str, force_venv: bool) -> None:
""" """
- uv venv .venv --python=xxx (optional) - uv venv .venv --python=xxx (optional)
@ -189,6 +203,7 @@ def setup_venv_and_deps(repo_dir: Path, repo_name: str, force_venv: bool) -> Non
if req_dev_file.is_file(): if req_dev_file.is_file():
uv_pip_install(repo_dir, ["-r", "requirements-dev.txt"]) uv_pip_install(repo_dir, ["-r", "requirements-dev.txt"])
def build_prompt(problem_statement: str, fail_tests: List[str], pass_tests: List[str]) -> str: def build_prompt(problem_statement: str, fail_tests: List[str], pass_tests: List[str]) -> str:
""" """
Construct the prompt text from problem_statement, FAIL_TO_PASS, PASS_TO_PASS. Construct the prompt text from problem_statement, FAIL_TO_PASS, PASS_TO_PASS.
@ -202,14 +217,25 @@ def build_prompt(problem_statement: str, fail_tests: List[str], pass_tests: List
for t in pass_tests: for t in pass_tests:
prompt += f"- {t}\n" prompt += f"- {t}\n"
prompt += "```\n\n" prompt += "```\n\n"
prompt += "\n\nYou must run all relevant tests both before and after making changes, and ensure they pass as you do your work."
return prompt return prompt
def process_instance( def process_instance(
instance: Dict[str, Any], instance: Dict[str, Any],
projects_dir: Path, projects_dir: Path,
reuse_repo: bool, reuse_repo: bool,
force_venv: bool force_venv: bool
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""
Process a single dataset instance without a progress bar/spinner.
- Clone or reuse the repo at projects_dir/<instance_id>
- Checkout commit
- Create or reuse a .venv in that repo
- Install ra-aid + any project dependencies
- Build prompt, run ra-aid (output streamed to console)
- Return prediction dict
"""
inst_id = instance.get("instance_id", "<unknown>") inst_id = instance.get("instance_id", "<unknown>")
repo_name = instance["repo"] repo_name = instance["repo"]
commit = instance["base_commit"] commit = instance["base_commit"]
@ -222,7 +248,6 @@ def process_instance(
if isinstance(pass_tests, str): if isinstance(pass_tests, str):
pass_tests = [pass_tests] pass_tests = [pass_tests]
# Build GH URL
if "github.com" not in repo_name: if "github.com" not in repo_name:
repo_url = f"https://github.com/{repo_name}.git" repo_url = f"https://github.com/{repo_name}.git"
else: else:
@ -230,13 +255,11 @@ def process_instance(
checkout_dir = projects_dir / f"{inst_id}" checkout_dir = projects_dir / f"{inst_id}"
# Clone or reuse
try: try:
if not checkout_dir.exists(): if not checkout_dir.exists():
logging.info(f"Cloning {repo_url} -> {checkout_dir}") logging.info(f"Cloning {repo_url} -> {checkout_dir}")
repo = Repo.clone_from(repo_url, checkout_dir) repo = Repo.clone_from(repo_url, checkout_dir)
else: else:
# if reuse_repo
if reuse_repo: if reuse_repo:
logging.info(f"Reusing existing directory: {checkout_dir}") logging.info(f"Reusing existing directory: {checkout_dir}")
repo = Repo(checkout_dir) repo = Repo(checkout_dir)
@ -245,7 +268,7 @@ def process_instance(
shutil.rmtree(checkout_dir) shutil.rmtree(checkout_dir)
repo = Repo.clone_from(repo_url, checkout_dir) repo = Repo.clone_from(repo_url, checkout_dir)
# checkout commit # checkout correct commit
repo.git.checkout(commit) repo.git.checkout(commit)
# set up venv + deps # set up venv + deps
@ -269,8 +292,11 @@ def process_instance(
"model_name_or_path": "ra-aid" "model_name_or_path": "ra-aid"
} }
def main() -> None: def main() -> None:
parser = argparse.ArgumentParser(description="Generate predictions for SWE-bench Lite using uv + ra-aid.") parser = argparse.ArgumentParser(
description="Generate predictions for SWE-bench Lite using uv + ra-aid (no progress bar)."
)
parser.add_argument( parser.add_argument(
"output_dir", "output_dir",
type=Path, type=Path,
@ -307,43 +333,46 @@ def main() -> None:
from datasets import load_dataset from datasets import load_dataset
# Create base/log dirs and set up logging
base_dir, log_dir = create_output_dirs() base_dir, log_dir = create_output_dirs()
setup_logging(log_dir, args.verbose) setup_logging(log_dir, args.verbose)
logging.info("Starting script") logging.info("Starting script")
# Ensure projects dir
args.projects_dir.mkdir(parents=True, exist_ok=True) args.projects_dir.mkdir(parents=True, exist_ok=True)
# Load dataset
dataset = load_dataset_safely() dataset = load_dataset_safely()
if dataset is None: if dataset is None:
sys.exit(1) sys.exit(1)
# Combine dev + test
all_data = list(dataset["dev"]) + list(dataset["test"]) all_data = list(dataset["dev"]) + list(dataset["test"])
# Ensure output dir
args.output_dir.mkdir(parents=True, exist_ok=True) args.output_dir.mkdir(parents=True, exist_ok=True)
predictions_file = args.output_dir / "predictions.json" predictions_file = args.output_dir / "predictions.json"
predictions: List[Dict[str, str]] = [] predictions: List[Dict[str, str]] = []
limit = args.num_instances if args.num_instances else len(all_data) limit = args.num_instances if args.num_instances else len(all_data)
with Progress( # Just a simple for loop - no progress bar
SpinnerColumn(), logging.info(f"Processing up to {limit} instances.")
TextColumn("[progress.description]{task.description}"), for i, inst in enumerate(all_data):
TimeElapsedColumn(), if i >= limit:
transient=False break
) as progress:
task = progress.add_task("Processing instances...", total=limit)
for i, inst in enumerate(all_data):
if i >= limit:
break
pred = process_instance(inst, args.projects_dir, args.reuse_repo, args.force_venv)
predictions.append(pred)
progress.advance(task)
logging.info(f"=== Instance {i+1}/{limit}, ID={inst.get('instance_id')} ===")
pred = process_instance(inst, args.projects_dir, args.reuse_repo, args.force_venv)
predictions.append(pred)
# Save predictions
with open(predictions_file, "w", encoding="utf-8") as f: with open(predictions_file, "w", encoding="utf-8") as f:
json.dump(predictions, f, indent=2) json.dump(predictions, f, indent=2)
logging.info("Done generating predictions.") logging.info("Done generating predictions.")
if __name__ == "__main__": if __name__ == "__main__":
try: try:
main() main()
@ -352,4 +381,4 @@ if __name__ == "__main__":
sys.exit(1) sys.exit(1)
except Exception as e: except Exception as e:
logging.exception("Unhandled error occurred.") logging.exception("Unhandled error occurred.")
sys.exit(1) sys.exit(1)