SWEBench updates.

AI Christianson 2024-12-31 08:36:19 -05:00
parent 9303736418
commit 4fca32a508
1 changed file with 54 additions and 25 deletions


@@ -10,6 +10,8 @@ It:
- Installs `ra-aid` in editable mode + any project dependencies via `uv pip`
- Calls `uv run ra-aid` to generate a patch
- Writes out predictions in JSON format
+ No progress bar or spinner is used, allowing `ra-aid` output to stream directly.
"""
import argparse
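
For orientation, the steps the module docstring lists amount to a short sequence of `uv` invocations per instance. The sketch below restates them as argv lists; the editable-install path and the requirements file name are illustrative assumptions, while `--cowboy-mode` and `-m` appear verbatim later in this diff.

# Editorial sketch, not part of this commit: the per-instance tool calls, as argv lists.
PER_INSTANCE_COMMANDS = [
    ["uv", "venv", "--python=3.12"],                         # create .venv inside the checkout
    ["uv", "pip", "install", "-e", "/path/to/ra-aid"],       # hypothetical editable-install path
    ["uv", "pip", "install", "-r", "requirements-dev.txt"],  # only if the repo ships one
    ["uv", "run", "ra-aid", "--cowboy-mode", "-m", "<prompt text>"],
]
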
@@ -24,13 +26,13 @@ from typing import Optional, Tuple, Dict, Any, List
from git import Repo
from rich.logging import RichHandler
- from rich.progress import Progress, SpinnerColumn, TextColumn, TimeElapsedColumn
# If you'd like to override Python versions for specific repos:
PYTHON_VERSION_OVERRIDES = {
# "someorg/somerepo": "3.9",
}
def setup_logging(log_dir: Path, verbose: bool = False) -> None:
"""Configure logging with both file and console handlers."""
log_dir.mkdir(parents=True, exist_ok=True)
@@ -55,6 +57,7 @@ def setup_logging(log_dir: Path, verbose: bool = False) -> None:
console_handler.setLevel(logging.DEBUG if verbose else logging.INFO)
root_logger.addHandler(console_handler)
def load_dataset_safely() -> Optional[Any]:
"""Load SWE-bench Lite dataset with error handling."""
try:
@@ -65,6 +68,7 @@ def load_dataset_safely() -> Optional[Any]:
logging.error(f"Failed to load dataset: {e}")
return None
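
The body of `load_dataset_safely` is mostly elided by this hunk. As a hedged sketch, loading SWE-bench Lite from the Hugging Face Hub typically looks like the following; the exact dataset id used by the script is an assumption here.

from datasets import load_dataset

# Editorial sketch, not from the diff; the script later reads the "dev" and "test" splits.
dataset = load_dataset("princeton-nlp/SWE-bench_Lite")
print(len(dataset["dev"]), len(dataset["test"]))
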
def create_output_dirs() -> Tuple[Path, Path]:
"""Create base/log directory structure."""
date_str = datetime.now().strftime("%Y%m%d")
@@ -74,6 +78,7 @@ def create_output_dirs() -> Tuple[Path, Path]:
log_dir.mkdir(parents=True, exist_ok=True)
return base_dir, log_dir
def uv_venv(repo_dir: Path, repo_name: str, force_venv: bool) -> None:
"""
Create (or reuse) a .venv in 'repo_dir' using 'uv venv'.
@@ -87,7 +92,7 @@ def uv_venv(repo_dir: Path, repo_name: str, force_venv: bool) -> None:
logging.info(f"Removing existing .venv at {venv_dir}")
shutil.rmtree(venv_dir)
- python_version = PYTHON_VERSION_OVERRIDES.get(repo_name, None)
+ python_version = PYTHON_VERSION_OVERRIDES.get(repo_name, None) or "3.12"
cmd = ["uv", "venv"]
if python_version:
cmd.append(f"--python={python_version}")
@@ -98,6 +103,7 @@ def uv_venv(repo_dir: Path, repo_name: str, force_venv: bool) -> None:
except Exception as e:
logging.error(f"Failed to create venv in {repo_dir}: {e}")
def uv_pip_install(repo_dir: Path, args: List[str]) -> None:
"""
Run 'uv pip install ...' in the specified repo_dir.
@@ -109,9 +115,11 @@ def uv_pip_install(repo_dir: Path, args: List[str]) -> None:
except Exception as e:
logging.error(f"Failed to run uv pip install {args}: {e}")
def uv_run_raaid(repo_dir: Path, prompt: str) -> Optional[str]:
"""
- Call 'uv run ra-aid' with the given prompt in the environment.
+ Call 'uv run ra-aid' with the given prompt in the environment,
+ streaming output directly to the console (capture_output=False).
Returns the patch if successful, else None.
"""
cmd = [
@@ -119,12 +127,16 @@ def uv_run_raaid(repo_dir: Path, prompt: str) -> Optional[str]:
"--cowboy-mode",
"-m", prompt
]
+ # We are NOT capturing output, so it streams live:
try:
- result = subprocess.run(cmd, cwd=repo_dir, text=True, capture_output=True, timeout=300)
+ result = subprocess.run(
+ cmd,
+ cwd=repo_dir,
+ text=True,
+ check=False, # We manually handle exit code
+ )
if result.returncode != 0:
logging.error("ra-aid returned non-zero exit code.")
logging.debug(f"stdout: {result.stdout}")
logging.debug(f"stderr: {result.stderr}")
return None
except subprocess.TimeoutExpired:
logging.error("ra-aid timed out")
@@ -137,6 +149,7 @@ def uv_run_raaid(repo_dir: Path, prompt: str) -> Optional[str]:
patch = get_git_patch(repo_dir)
return patch
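
The switch above from `capture_output=True` to a plain `subprocess.run` call is what lets ra-aid's output stream live. A standalone contrast of the two modes (assumes a POSIX `echo` binary for the demo):

import subprocess

# Editorial illustration, not part of this commit.
captured = subprocess.run(["echo", "hi"], text=True, capture_output=True)
print(repr(captured.stdout))  # 'hi\n' -- buffered until the child exits

streamed = subprocess.run(["echo", "hi"], text=True, check=False)
print(streamed.stdout)        # None -- output already went straight to the console
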
def get_git_patch(repo_dir: Path) -> Optional[str]:
"""Generate a git patch from the current changes in `repo_dir`."""
try:
@@ -154,6 +167,7 @@ def get_git_patch(repo_dir: Path) -> Optional[str]:
logging.error(f"Failed to generate patch: {e}")
return None
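
Most of `get_git_patch` is elided by this hunk. One plausible GitPython-based shape for such a helper is sketched below; the function name, staging step, and diff flags are assumptions, not the committed code.

from git import Repo

# Editorial sketch, not the committed implementation.
def sketch_git_patch(repo_dir: str) -> str:
    repo = Repo(repo_dir)
    repo.git.add("-A")                         # stage everything, including new files
    return repo.git.diff("HEAD", cached=True)  # unified diff of the staged changes
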
def setup_venv_and_deps(repo_dir: Path, repo_name: str, force_venv: bool) -> None:
"""
- uv venv .venv --python=xxx (optional)
@@ -189,6 +203,7 @@ def setup_venv_and_deps(repo_dir: Path, repo_name: str, force_venv: bool) -> Non
if req_dev_file.is_file():
uv_pip_install(repo_dir, ["-r", "requirements-dev.txt"])
def build_prompt(problem_statement: str, fail_tests: List[str], pass_tests: List[str]) -> str:
"""
Construct the prompt text from problem_statement, FAIL_TO_PASS, PASS_TO_PASS.
@@ -202,14 +217,25 @@ def build_prompt(problem_statement: str, fail_tests: List[str], pass_tests: List
for t in pass_tests:
prompt += f"- {t}\n"
prompt += "```\n\n"
prompt += "\n\nYou must run all relevant tests both before and after making changes, and ensure they pass as you do your work."
return prompt
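
A hedged usage example of `build_prompt`, with an invented problem statement and test ids standing in for an instance's FAIL_TO_PASS / PASS_TO_PASS lists:

# Editorial example; the problem statement and test ids below are hypothetical.
prompt = build_prompt(
    problem_statement="TypeError raised when parsing an empty config file",
    fail_tests=["tests/test_config.py::test_empty_file"],
    pass_tests=["tests/test_config.py::test_basic_roundtrip"],
)
print(prompt)  # ends with the instruction to run the relevant tests before and after changes
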
def process_instance(
instance: Dict[str, Any],
projects_dir: Path,
reuse_repo: bool,
force_venv: bool
) -> Dict[str, Any]:
"""
Process a single dataset instance without a progress bar/spinner.
- Clone or reuse the repo at projects_dir/<instance_id>
- Checkout commit
- Create or reuse a .venv in that repo
- Install ra-aid + any project dependencies
- Build prompt, run ra-aid (output streamed to console)
- Return prediction dict
"""
inst_id = instance.get("instance_id", "<unknown>")
repo_name = instance["repo"]
commit = instance["base_commit"]
@@ -222,7 +248,6 @@ def process_instance(
if isinstance(pass_tests, str):
pass_tests = [pass_tests]
# Build GH URL
if "github.com" not in repo_name:
repo_url = f"https://github.com/{repo_name}.git"
else:
@@ -230,13 +255,11 @@ def process_instance(
checkout_dir = projects_dir / f"{inst_id}"
# Clone or reuse
try:
if not checkout_dir.exists():
logging.info(f"Cloning {repo_url} -> {checkout_dir}")
repo = Repo.clone_from(repo_url, checkout_dir)
else:
# if reuse_repo
if reuse_repo:
logging.info(f"Reusing existing directory: {checkout_dir}")
repo = Repo(checkout_dir)
@@ -245,7 +268,7 @@ def process_instance(
shutil.rmtree(checkout_dir)
repo = Repo.clone_from(repo_url, checkout_dir)
- # checkout commit
+ # checkout correct commit
repo.git.checkout(commit)
# set up venv + deps
@@ -269,8 +292,11 @@ def process_instance(
"model_name_or_path": "ra-aid"
}
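
Only the `model_name_or_path` key of the returned prediction is visible in this hunk. For reference, a SWE-bench-style prediction record generally carries the three keys sketched below; the placeholder values, and the assumption that no other keys are present, are editorial.

# Editorial sketch of one predictions.json entry; values are placeholders.
prediction = {
    "instance_id": "example__repo-1234",  # hypothetical SWE-bench instance id
    "model_patch": "diff --git a/...",    # patch text produced by ra-aid
    "model_name_or_path": "ra-aid",
}
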
def main() -> None:
- parser = argparse.ArgumentParser(description="Generate predictions for SWE-bench Lite using uv + ra-aid.")
+ parser = argparse.ArgumentParser(
+ description="Generate predictions for SWE-bench Lite using uv + ra-aid (no progress bar)."
+ )
parser.add_argument(
"output_dir",
type=Path,
@@ -307,43 +333,46 @@ def main() -> None:
from datasets import load_dataset
# Create base/log dirs and set up logging
base_dir, log_dir = create_output_dirs()
setup_logging(log_dir, args.verbose)
logging.info("Starting script")
# Ensure projects dir
args.projects_dir.mkdir(parents=True, exist_ok=True)
# Load dataset
dataset = load_dataset_safely()
if dataset is None:
sys.exit(1)
# Combine dev + test
all_data = list(dataset["dev"]) + list(dataset["test"])
# Ensure output dir
args.output_dir.mkdir(parents=True, exist_ok=True)
predictions_file = args.output_dir / "predictions.json"
predictions: List[Dict[str, str]] = []
limit = args.num_instances if args.num_instances else len(all_data)
- with Progress(
- SpinnerColumn(),
- TextColumn("[progress.description]{task.description}"),
- TimeElapsedColumn(),
- transient=False
- ) as progress:
- task = progress.add_task("Processing instances...", total=limit)
- for i, inst in enumerate(all_data):
- if i >= limit:
- break
- pred = process_instance(inst, args.projects_dir, args.reuse_repo, args.force_venv)
- predictions.append(pred)
- progress.advance(task)
+ # Just a simple for loop - no progress bar
+ logging.info(f"Processing up to {limit} instances.")
+ for i, inst in enumerate(all_data):
+ if i >= limit:
+ break
+ logging.info(f"=== Instance {i+1}/{limit}, ID={inst.get('instance_id')} ===")
+ pred = process_instance(inst, args.projects_dir, args.reuse_repo, args.force_venv)
+ predictions.append(pred)
# Save predictions
with open(predictions_file, "w", encoding="utf-8") as f:
json.dump(predictions, f, indent=2)
logging.info("Done generating predictions.")
if __name__ == "__main__":
try:
main()
@@ -352,4 +381,4 @@ if __name__ == "__main__":
sys.exit(1)
except Exception as e:
logging.exception("Unhandled error occurred.")
- sys.exit(1)
+ sys.exit(1)