SWEBench updates.
parent 9303736418
commit 4fca32a508
@@ -10,6 +10,8 @@ It:
- Installs `ra-aid` in editable mode + any project dependencies via `uv pip`
- Calls `uv run ra-aid` to generate a patch
- Writes out predictions in JSON format

No progress bar or spinner is used, allowing `ra-aid` output to stream directly.
"""

import argparse
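For reference, one entry in the predictions JSON presumably looks like the sketch below. Only "model_name_or_path": "ra-aid" is visible later in this diff; the remaining keys follow the usual SWE-bench predictions schema, and the values are placeholders.

    # Hypothetical predictions.json entry (illustrative values only).
    example_prediction = {
        "instance_id": "someorg__somerepo-1234",
        "model_name_or_path": "ra-aid",
        "model_patch": "diff --git a/pkg/module.py b/pkg/module.py\n...",
    }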
@@ -24,13 +26,13 @@ from typing import Optional, Tuple, Dict, Any, List

from git import Repo
from rich.logging import RichHandler
from rich.progress import Progress, SpinnerColumn, TextColumn, TimeElapsedColumn

# If you'd like to override Python versions for specific repos:
PYTHON_VERSION_OVERRIDES = {
    # "someorg/somerepo": "3.9",
}


def setup_logging(log_dir: Path, verbose: bool = False) -> None:
    """Configure logging with both file and console handlers."""
    log_dir.mkdir(parents=True, exist_ok=True)
@@ -55,6 +57,7 @@ def setup_logging(log_dir: Path, verbose: bool = False) -> None:
    console_handler.setLevel(logging.DEBUG if verbose else logging.INFO)
    root_logger.addHandler(console_handler)


def load_dataset_safely() -> Optional[Any]:
    """Load SWE-bench Lite dataset with error handling."""
    try:
@@ -65,6 +68,7 @@ def load_dataset_safely() -> Optional[Any]:
        logging.error(f"Failed to load dataset: {e}")
        return None


def create_output_dirs() -> Tuple[Path, Path]:
    """Create base/log directory structure."""
    date_str = datetime.now().strftime("%Y%m%d")
@@ -74,6 +78,7 @@ def create_output_dirs() -> Tuple[Path, Path]:
    log_dir.mkdir(parents=True, exist_ok=True)
    return base_dir, log_dir


def uv_venv(repo_dir: Path, repo_name: str, force_venv: bool) -> None:
    """
    Create (or reuse) a .venv in 'repo_dir' using 'uv venv'.
@@ -87,7 +92,7 @@ def uv_venv(repo_dir: Path, repo_name: str, force_venv: bool) -> None:
        logging.info(f"Removing existing .venv at {venv_dir}")
        shutil.rmtree(venv_dir)

    python_version = PYTHON_VERSION_OVERRIDES.get(repo_name, None)
    python_version = PYTHON_VERSION_OVERRIDES.get(repo_name, None) or "3.12"
    cmd = ["uv", "venv"]
    if python_version:
        cmd.append(f"--python={python_version}")
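The command assembled here is presumably executed with subprocess.run inside the repository, roughly as in this minimal sketch (create_venv is a hypothetical stand-in; the real function adds the logging and error handling shown in the surrounding hunks):

    import subprocess
    from pathlib import Path

    def create_venv(repo_dir: Path, python_version: str = "3.12") -> None:
        # Build "uv venv --python=3.12" and run it inside the repo so the
        # virtual environment is created at repo_dir/.venv.
        cmd = ["uv", "venv", f"--python={python_version}"]
        subprocess.run(cmd, cwd=repo_dir, check=True, text=True)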
@@ -98,6 +103,7 @@ def uv_venv(repo_dir: Path, repo_name: str, force_venv: bool) -> None:
    except Exception as e:
        logging.error(f"Failed to create venv in {repo_dir}: {e}")


def uv_pip_install(repo_dir: Path, args: List[str]) -> None:
    """
    Run 'uv pip install ...' in the specified repo_dir.
@@ -109,9 +115,11 @@ def uv_pip_install(repo_dir: Path, args: List[str]) -> None:
    except Exception as e:
        logging.error(f"Failed to run uv pip install {args}: {e}")


def uv_run_raaid(repo_dir: Path, prompt: str) -> Optional[str]:
    """
    Call 'uv run ra-aid' with the given prompt in the environment.
    Call 'uv run ra-aid' with the given prompt in the environment,
    streaming output directly to the console (capture_output=False).
    Returns the patch if successful, else None.
    """
    cmd = [
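The hunk ends at the opening `cmd = [`; based on the function's docstring and the arguments visible in the next hunk, the full command list is presumably along these lines (the "uv", "run", "ra-aid" prefix is inferred, not shown in the diff):

    # Presumed shape of the command; only "--cowboy-mode" and "-m" appear in
    # the diff itself, the prefix is inferred from the docstring.
    prompt = "Fix the failing tests described below..."  # placeholder
    cmd = ["uv", "run", "ra-aid", "--cowboy-mode", "-m", prompt]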
@@ -119,12 +127,16 @@ def uv_run_raaid(repo_dir: Path, prompt: str) -> Optional[str]:
        "--cowboy-mode",
        "-m", prompt
    ]
    # We are NOT capturing output, so it streams live:
    try:
        result = subprocess.run(cmd, cwd=repo_dir, text=True, capture_output=True, timeout=300)
        result = subprocess.run(
            cmd,
            cwd=repo_dir,
            text=True,
            check=False,  # We manually handle exit code
        )
        if result.returncode != 0:
            logging.error("ra-aid returned non-zero exit code.")
            logging.debug(f"stdout: {result.stdout}")
            logging.debug(f"stderr: {result.stderr}")
            return None
    except subprocess.TimeoutExpired:
        logging.error("ra-aid timed out")
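Dropping capture_output=True is what makes the output stream: a captured child process buffers stdout/stderr into the result object, while an uncaptured one inherits the parent's streams and prints live. A standalone illustration, separate from the script:

    import subprocess
    import sys

    # With capture_output=True the child's output is only available after it
    # exits, via result.stdout; without it, the child inherits this process's
    # stdout/stderr and its output appears live on the console.
    captured = subprocess.run(
        [sys.executable, "-c", "print('buffered')"], text=True, capture_output=True
    )
    print(captured.stdout, end="")
    subprocess.run([sys.executable, "-c", "print('streamed live')"], text=True)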
@@ -137,6 +149,7 @@ def uv_run_raaid(repo_dir: Path, prompt: str) -> Optional[str]:
    patch = get_git_patch(repo_dir)
    return patch


def get_git_patch(repo_dir: Path) -> Optional[str]:
    """Generate a git patch from the current changes in `repo_dir`."""
    try:
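The body of get_git_patch is not part of this diff; a minimal sketch of one way to implement it with GitPython (already imported above), not necessarily the author's version:

    from pathlib import Path
    from typing import Optional
    from git import Repo

    def get_git_patch_sketch(repo_dir: Path) -> Optional[str]:
        # Diff the working tree against HEAD and return the unified diff,
        # or None when nothing changed (untracked files are not included).
        repo = Repo(repo_dir)
        patch = repo.git.diff("HEAD")
        return patch if patch.strip() else None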
@@ -154,6 +167,7 @@ def get_git_patch(repo_dir: Path) -> Optional[str]:
        logging.error(f"Failed to generate patch: {e}")
        return None


def setup_venv_and_deps(repo_dir: Path, repo_name: str, force_venv: bool) -> None:
    """
    - uv venv .venv --python=xxx (optional)
@@ -189,6 +203,7 @@ def setup_venv_and_deps(repo_dir: Path, repo_name: str, force_venv: bool) -> None:
    if req_dev_file.is_file():
        uv_pip_install(repo_dir, ["-r", "requirements-dev.txt"])


def build_prompt(problem_statement: str, fail_tests: List[str], pass_tests: List[str]) -> str:
    """
    Construct the prompt text from problem_statement, FAIL_TO_PASS, PASS_TO_PASS.
@@ -202,14 +217,25 @@ def build_prompt(problem_statement: str, fail_tests: List[str], pass_tests: List[str]) -> str:
    for t in pass_tests:
        prompt += f"- {t}\n"
    prompt += "```\n\n"
    prompt += "\n\nYou must run all relevant tests both before and after making changes, and ensure they pass as you do your work."
    return prompt


def process_instance(
    instance: Dict[str, Any],
    projects_dir: Path,
    reuse_repo: bool,
    force_venv: bool
) -> Dict[str, Any]:
    """
    Process a single dataset instance without a progress bar/spinner.
    - Clone or reuse the repo at projects_dir/<instance_id>
    - Checkout commit
    - Create or reuse a .venv in that repo
    - Install ra-aid + any project dependencies
    - Build prompt, run ra-aid (output streamed to console)
    - Return prediction dict
    """
    inst_id = instance.get("instance_id", "<unknown>")
    repo_name = instance["repo"]
    commit = instance["base_commit"]
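For orientation, process_instance() consumes SWE-bench Lite records shaped roughly like the sketch below; the field names match what this function and build_prompt refer to, while the values are purely illustrative.

    # Illustrative SWE-bench Lite record (placeholder values).
    instance = {
        "instance_id": "someorg__somerepo-1234",
        "repo": "someorg/somerepo",
        "base_commit": "0123456789abcdef0123456789abcdef01234567",
        "problem_statement": "Short description of the reported bug...",
        "FAIL_TO_PASS": ["tests/test_bug.py::test_regression"],
        "PASS_TO_PASS": ["tests/test_core.py::test_existing_behavior"],
    }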
@@ -222,7 +248,6 @@ def process_instance(
    if isinstance(pass_tests, str):
        pass_tests = [pass_tests]

    # Build GH URL
    if "github.com" not in repo_name:
        repo_url = f"https://github.com/{repo_name}.git"
    else:
@@ -230,13 +255,11 @@ def process_instance(

    checkout_dir = projects_dir / f"{inst_id}"

    # Clone or reuse
    try:
        if not checkout_dir.exists():
            logging.info(f"Cloning {repo_url} -> {checkout_dir}")
            repo = Repo.clone_from(repo_url, checkout_dir)
        else:
            # if reuse_repo
            if reuse_repo:
                logging.info(f"Reusing existing directory: {checkout_dir}")
                repo = Repo(checkout_dir)
@@ -245,7 +268,7 @@ def process_instance(
                shutil.rmtree(checkout_dir)
                repo = Repo.clone_from(repo_url, checkout_dir)

        # checkout commit
        # checkout correct commit
        repo.git.checkout(commit)

        # set up venv + deps
@@ -269,8 +292,11 @@ def process_instance(
        "model_name_or_path": "ra-aid"
    }


def main() -> None:
    parser = argparse.ArgumentParser(description="Generate predictions for SWE-bench Lite using uv + ra-aid.")
    parser = argparse.ArgumentParser(
        description="Generate predictions for SWE-bench Lite using uv + ra-aid (no progress bar)."
    )
    parser.add_argument(
        "output_dir",
        type=Path,
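Only the output_dir positional argument is visible in this hunk; the remaining options implied by the args.* attributes used later in main() presumably look something like this sketch (flag names and defaults are assumptions, not the diff's code):

    import argparse
    from pathlib import Path

    # Hypothetical reconstruction of the CLI options; names and defaults are
    # guesses based on the args.* attributes read in main().
    parser = argparse.ArgumentParser()
    parser.add_argument("output_dir", type=Path)
    parser.add_argument("--projects-dir", type=Path, default=Path("projects"))
    parser.add_argument("--num-instances", type=int, default=None)
    parser.add_argument("--reuse-repo", action="store_true")
    parser.add_argument("--force-venv", action="store_true")
    parser.add_argument("--verbose", action="store_true")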
@@ -307,43 +333,46 @@ def main() -> None:

    from datasets import load_dataset

    # Create base/log dirs and set up logging
    base_dir, log_dir = create_output_dirs()
    setup_logging(log_dir, args.verbose)
    logging.info("Starting script")

    # Ensure projects dir
    args.projects_dir.mkdir(parents=True, exist_ok=True)

    # Load dataset
    dataset = load_dataset_safely()
    if dataset is None:
        sys.exit(1)

    # Combine dev + test
    all_data = list(dataset["dev"]) + list(dataset["test"])

    # Ensure output dir
    args.output_dir.mkdir(parents=True, exist_ok=True)
    predictions_file = args.output_dir / "predictions.json"
    predictions: List[Dict[str, str]] = []

    limit = args.num_instances if args.num_instances else len(all_data)

    with Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        TimeElapsedColumn(),
        transient=False
    ) as progress:
        task = progress.add_task("Processing instances...", total=limit)
        for i, inst in enumerate(all_data):
            if i >= limit:
                break
            pred = process_instance(inst, args.projects_dir, args.reuse_repo, args.force_venv)
            predictions.append(pred)
            progress.advance(task)
    # Just a simple for loop - no progress bar
    logging.info(f"Processing up to {limit} instances.")
    for i, inst in enumerate(all_data):
        if i >= limit:
            break

        logging.info(f"=== Instance {i+1}/{limit}, ID={inst.get('instance_id')} ===")
        pred = process_instance(inst, args.projects_dir, args.reuse_repo, args.force_venv)
        predictions.append(pred)

    # Save predictions
    with open(predictions_file, "w", encoding="utf-8") as f:
        json.dump(predictions, f, indent=2)

    logging.info("Done generating predictions.")


if __name__ == "__main__":
    try:
        main()
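Once the run finishes, predictions.json can be read back for inspection or handed to the SWE-bench evaluation harness; a small sketch (the path relative to output_dir is an assumption):

    import json
    from pathlib import Path

    # Load the predictions written by main().
    predictions = json.loads(Path("predictions.json").read_text(encoding="utf-8"))
    print(f"{len(predictions)} predictions generated by "
          f"{predictions[0]['model_name_or_path']}")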
@@ -352,4 +381,4 @@ if __name__ == "__main__":
        sys.exit(1)
    except Exception as e:
        logging.exception("Unhandled error occurred.")
        sys.exit(1)
        sys.exit(1)