SWEBench updates.
This commit is contained in:
parent
9303736418
commit
4fca32a508
|
|
@ -10,6 +10,8 @@ It:
|
||||||
- Installs `ra-aid` in editable mode + any project dependencies via `uv pip`
|
- Installs `ra-aid` in editable mode + any project dependencies via `uv pip`
|
||||||
- Calls `uv run ra-aid` to generate a patch
|
- Calls `uv run ra-aid` to generate a patch
|
||||||
- Writes out predictions in JSON format
|
- Writes out predictions in JSON format
|
||||||
|
|
||||||
|
No progress bar or spinner is used, allowing `ra-aid` output to stream directly.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
|
@ -24,13 +26,13 @@ from typing import Optional, Tuple, Dict, Any, List
|
||||||
|
|
||||||
from git import Repo
|
from git import Repo
|
||||||
from rich.logging import RichHandler
|
from rich.logging import RichHandler
|
||||||
from rich.progress import Progress, SpinnerColumn, TextColumn, TimeElapsedColumn
|
|
||||||
|
|
||||||
# If you'd like to override Python versions for specific repos:
|
# If you'd like to override Python versions for specific repos:
|
||||||
PYTHON_VERSION_OVERRIDES = {
|
PYTHON_VERSION_OVERRIDES = {
|
||||||
# "someorg/somerepo": "3.9",
|
# "someorg/somerepo": "3.9",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def setup_logging(log_dir: Path, verbose: bool = False) -> None:
|
def setup_logging(log_dir: Path, verbose: bool = False) -> None:
|
||||||
"""Configure logging with both file and console handlers."""
|
"""Configure logging with both file and console handlers."""
|
||||||
log_dir.mkdir(parents=True, exist_ok=True)
|
log_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
@ -55,6 +57,7 @@ def setup_logging(log_dir: Path, verbose: bool = False) -> None:
|
||||||
console_handler.setLevel(logging.DEBUG if verbose else logging.INFO)
|
console_handler.setLevel(logging.DEBUG if verbose else logging.INFO)
|
||||||
root_logger.addHandler(console_handler)
|
root_logger.addHandler(console_handler)
|
||||||
|
|
||||||
|
|
||||||
def load_dataset_safely() -> Optional[Any]:
|
def load_dataset_safely() -> Optional[Any]:
|
||||||
"""Load SWE-bench Lite dataset with error handling."""
|
"""Load SWE-bench Lite dataset with error handling."""
|
||||||
try:
|
try:
|
||||||
|
|
@ -65,6 +68,7 @@ def load_dataset_safely() -> Optional[Any]:
|
||||||
logging.error(f"Failed to load dataset: {e}")
|
logging.error(f"Failed to load dataset: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def create_output_dirs() -> Tuple[Path, Path]:
|
def create_output_dirs() -> Tuple[Path, Path]:
|
||||||
"""Create base/log directory structure."""
|
"""Create base/log directory structure."""
|
||||||
date_str = datetime.now().strftime("%Y%m%d")
|
date_str = datetime.now().strftime("%Y%m%d")
|
||||||
|
|
@ -74,6 +78,7 @@ def create_output_dirs() -> Tuple[Path, Path]:
|
||||||
log_dir.mkdir(parents=True, exist_ok=True)
|
log_dir.mkdir(parents=True, exist_ok=True)
|
||||||
return base_dir, log_dir
|
return base_dir, log_dir
|
||||||
|
|
||||||
|
|
||||||
def uv_venv(repo_dir: Path, repo_name: str, force_venv: bool) -> None:
|
def uv_venv(repo_dir: Path, repo_name: str, force_venv: bool) -> None:
|
||||||
"""
|
"""
|
||||||
Create (or reuse) a .venv in 'repo_dir' using 'uv venv'.
|
Create (or reuse) a .venv in 'repo_dir' using 'uv venv'.
|
||||||
|
|
@ -87,7 +92,7 @@ def uv_venv(repo_dir: Path, repo_name: str, force_venv: bool) -> None:
|
||||||
logging.info(f"Removing existing .venv at {venv_dir}")
|
logging.info(f"Removing existing .venv at {venv_dir}")
|
||||||
shutil.rmtree(venv_dir)
|
shutil.rmtree(venv_dir)
|
||||||
|
|
||||||
python_version = PYTHON_VERSION_OVERRIDES.get(repo_name, None)
|
python_version = PYTHON_VERSION_OVERRIDES.get(repo_name, None) or "3.12"
|
||||||
cmd = ["uv", "venv"]
|
cmd = ["uv", "venv"]
|
||||||
if python_version:
|
if python_version:
|
||||||
cmd.append(f"--python={python_version}")
|
cmd.append(f"--python={python_version}")
|
||||||
|
|
@ -98,6 +103,7 @@ def uv_venv(repo_dir: Path, repo_name: str, force_venv: bool) -> None:
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error(f"Failed to create venv in {repo_dir}: {e}")
|
logging.error(f"Failed to create venv in {repo_dir}: {e}")
|
||||||
|
|
||||||
|
|
||||||
def uv_pip_install(repo_dir: Path, args: List[str]) -> None:
|
def uv_pip_install(repo_dir: Path, args: List[str]) -> None:
|
||||||
"""
|
"""
|
||||||
Run 'uv pip install ...' in the specified repo_dir.
|
Run 'uv pip install ...' in the specified repo_dir.
|
||||||
|
|
@ -109,9 +115,11 @@ def uv_pip_install(repo_dir: Path, args: List[str]) -> None:
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error(f"Failed to run uv pip install {args}: {e}")
|
logging.error(f"Failed to run uv pip install {args}: {e}")
|
||||||
|
|
||||||
|
|
||||||
def uv_run_raaid(repo_dir: Path, prompt: str) -> Optional[str]:
|
def uv_run_raaid(repo_dir: Path, prompt: str) -> Optional[str]:
|
||||||
"""
|
"""
|
||||||
Call 'uv run ra-aid' with the given prompt in the environment.
|
Call 'uv run ra-aid' with the given prompt in the environment,
|
||||||
|
streaming output directly to the console (capture_output=False).
|
||||||
Returns the patch if successful, else None.
|
Returns the patch if successful, else None.
|
||||||
"""
|
"""
|
||||||
cmd = [
|
cmd = [
|
||||||
|
|
@ -119,12 +127,16 @@ def uv_run_raaid(repo_dir: Path, prompt: str) -> Optional[str]:
|
||||||
"--cowboy-mode",
|
"--cowboy-mode",
|
||||||
"-m", prompt
|
"-m", prompt
|
||||||
]
|
]
|
||||||
|
# We are NOT capturing output, so it streams live:
|
||||||
try:
|
try:
|
||||||
result = subprocess.run(cmd, cwd=repo_dir, text=True, capture_output=True, timeout=300)
|
result = subprocess.run(
|
||||||
|
cmd,
|
||||||
|
cwd=repo_dir,
|
||||||
|
text=True,
|
||||||
|
check=False, # We manually handle exit code
|
||||||
|
)
|
||||||
if result.returncode != 0:
|
if result.returncode != 0:
|
||||||
logging.error("ra-aid returned non-zero exit code.")
|
logging.error("ra-aid returned non-zero exit code.")
|
||||||
logging.debug(f"stdout: {result.stdout}")
|
|
||||||
logging.debug(f"stderr: {result.stderr}")
|
|
||||||
return None
|
return None
|
||||||
except subprocess.TimeoutExpired:
|
except subprocess.TimeoutExpired:
|
||||||
logging.error("ra-aid timed out")
|
logging.error("ra-aid timed out")
|
||||||
|
|
@ -137,6 +149,7 @@ def uv_run_raaid(repo_dir: Path, prompt: str) -> Optional[str]:
|
||||||
patch = get_git_patch(repo_dir)
|
patch = get_git_patch(repo_dir)
|
||||||
return patch
|
return patch
|
||||||
|
|
||||||
|
|
||||||
def get_git_patch(repo_dir: Path) -> Optional[str]:
|
def get_git_patch(repo_dir: Path) -> Optional[str]:
|
||||||
"""Generate a git patch from the current changes in `repo_dir`."""
|
"""Generate a git patch from the current changes in `repo_dir`."""
|
||||||
try:
|
try:
|
||||||
|
|
@ -154,6 +167,7 @@ def get_git_patch(repo_dir: Path) -> Optional[str]:
|
||||||
logging.error(f"Failed to generate patch: {e}")
|
logging.error(f"Failed to generate patch: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def setup_venv_and_deps(repo_dir: Path, repo_name: str, force_venv: bool) -> None:
|
def setup_venv_and_deps(repo_dir: Path, repo_name: str, force_venv: bool) -> None:
|
||||||
"""
|
"""
|
||||||
- uv venv .venv --python=xxx (optional)
|
- uv venv .venv --python=xxx (optional)
|
||||||
|
|
@ -189,6 +203,7 @@ def setup_venv_and_deps(repo_dir: Path, repo_name: str, force_venv: bool) -> Non
|
||||||
if req_dev_file.is_file():
|
if req_dev_file.is_file():
|
||||||
uv_pip_install(repo_dir, ["-r", "requirements-dev.txt"])
|
uv_pip_install(repo_dir, ["-r", "requirements-dev.txt"])
|
||||||
|
|
||||||
|
|
||||||
def build_prompt(problem_statement: str, fail_tests: List[str], pass_tests: List[str]) -> str:
|
def build_prompt(problem_statement: str, fail_tests: List[str], pass_tests: List[str]) -> str:
|
||||||
"""
|
"""
|
||||||
Construct the prompt text from problem_statement, FAIL_TO_PASS, PASS_TO_PASS.
|
Construct the prompt text from problem_statement, FAIL_TO_PASS, PASS_TO_PASS.
|
||||||
|
|
@ -202,14 +217,25 @@ def build_prompt(problem_statement: str, fail_tests: List[str], pass_tests: List
|
||||||
for t in pass_tests:
|
for t in pass_tests:
|
||||||
prompt += f"- {t}\n"
|
prompt += f"- {t}\n"
|
||||||
prompt += "```\n\n"
|
prompt += "```\n\n"
|
||||||
|
prompt += "\n\nYou must run all relevant tests both before and after making changes, and ensure they pass as you do your work."
|
||||||
return prompt
|
return prompt
|
||||||
|
|
||||||
|
|
||||||
def process_instance(
|
def process_instance(
|
||||||
instance: Dict[str, Any],
|
instance: Dict[str, Any],
|
||||||
projects_dir: Path,
|
projects_dir: Path,
|
||||||
reuse_repo: bool,
|
reuse_repo: bool,
|
||||||
force_venv: bool
|
force_venv: bool
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Process a single dataset instance without a progress bar/spinner.
|
||||||
|
- Clone or reuse the repo at projects_dir/<instance_id>
|
||||||
|
- Checkout commit
|
||||||
|
- Create or reuse a .venv in that repo
|
||||||
|
- Install ra-aid + any project dependencies
|
||||||
|
- Build prompt, run ra-aid (output streamed to console)
|
||||||
|
- Return prediction dict
|
||||||
|
"""
|
||||||
inst_id = instance.get("instance_id", "<unknown>")
|
inst_id = instance.get("instance_id", "<unknown>")
|
||||||
repo_name = instance["repo"]
|
repo_name = instance["repo"]
|
||||||
commit = instance["base_commit"]
|
commit = instance["base_commit"]
|
||||||
|
|
@ -222,7 +248,6 @@ def process_instance(
|
||||||
if isinstance(pass_tests, str):
|
if isinstance(pass_tests, str):
|
||||||
pass_tests = [pass_tests]
|
pass_tests = [pass_tests]
|
||||||
|
|
||||||
# Build GH URL
|
|
||||||
if "github.com" not in repo_name:
|
if "github.com" not in repo_name:
|
||||||
repo_url = f"https://github.com/{repo_name}.git"
|
repo_url = f"https://github.com/{repo_name}.git"
|
||||||
else:
|
else:
|
||||||
|
|
@ -230,13 +255,11 @@ def process_instance(
|
||||||
|
|
||||||
checkout_dir = projects_dir / f"{inst_id}"
|
checkout_dir = projects_dir / f"{inst_id}"
|
||||||
|
|
||||||
# Clone or reuse
|
|
||||||
try:
|
try:
|
||||||
if not checkout_dir.exists():
|
if not checkout_dir.exists():
|
||||||
logging.info(f"Cloning {repo_url} -> {checkout_dir}")
|
logging.info(f"Cloning {repo_url} -> {checkout_dir}")
|
||||||
repo = Repo.clone_from(repo_url, checkout_dir)
|
repo = Repo.clone_from(repo_url, checkout_dir)
|
||||||
else:
|
else:
|
||||||
# if reuse_repo
|
|
||||||
if reuse_repo:
|
if reuse_repo:
|
||||||
logging.info(f"Reusing existing directory: {checkout_dir}")
|
logging.info(f"Reusing existing directory: {checkout_dir}")
|
||||||
repo = Repo(checkout_dir)
|
repo = Repo(checkout_dir)
|
||||||
|
|
@ -245,7 +268,7 @@ def process_instance(
|
||||||
shutil.rmtree(checkout_dir)
|
shutil.rmtree(checkout_dir)
|
||||||
repo = Repo.clone_from(repo_url, checkout_dir)
|
repo = Repo.clone_from(repo_url, checkout_dir)
|
||||||
|
|
||||||
# checkout commit
|
# checkout correct commit
|
||||||
repo.git.checkout(commit)
|
repo.git.checkout(commit)
|
||||||
|
|
||||||
# set up venv + deps
|
# set up venv + deps
|
||||||
|
|
@ -269,8 +292,11 @@ def process_instance(
|
||||||
"model_name_or_path": "ra-aid"
|
"model_name_or_path": "ra-aid"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def main() -> None:
|
def main() -> None:
|
||||||
parser = argparse.ArgumentParser(description="Generate predictions for SWE-bench Lite using uv + ra-aid.")
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Generate predictions for SWE-bench Lite using uv + ra-aid (no progress bar)."
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"output_dir",
|
"output_dir",
|
||||||
type=Path,
|
type=Path,
|
||||||
|
|
@ -307,43 +333,46 @@ def main() -> None:
|
||||||
|
|
||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
|
|
||||||
|
# Create base/log dirs and set up logging
|
||||||
base_dir, log_dir = create_output_dirs()
|
base_dir, log_dir = create_output_dirs()
|
||||||
setup_logging(log_dir, args.verbose)
|
setup_logging(log_dir, args.verbose)
|
||||||
logging.info("Starting script")
|
logging.info("Starting script")
|
||||||
|
|
||||||
|
# Ensure projects dir
|
||||||
args.projects_dir.mkdir(parents=True, exist_ok=True)
|
args.projects_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# Load dataset
|
||||||
dataset = load_dataset_safely()
|
dataset = load_dataset_safely()
|
||||||
if dataset is None:
|
if dataset is None:
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
|
# Combine dev + test
|
||||||
all_data = list(dataset["dev"]) + list(dataset["test"])
|
all_data = list(dataset["dev"]) + list(dataset["test"])
|
||||||
|
|
||||||
|
# Ensure output dir
|
||||||
args.output_dir.mkdir(parents=True, exist_ok=True)
|
args.output_dir.mkdir(parents=True, exist_ok=True)
|
||||||
predictions_file = args.output_dir / "predictions.json"
|
predictions_file = args.output_dir / "predictions.json"
|
||||||
predictions: List[Dict[str, str]] = []
|
predictions: List[Dict[str, str]] = []
|
||||||
|
|
||||||
limit = args.num_instances if args.num_instances else len(all_data)
|
limit = args.num_instances if args.num_instances else len(all_data)
|
||||||
|
|
||||||
with Progress(
|
# Just a simple for loop - no progress bar
|
||||||
SpinnerColumn(),
|
logging.info(f"Processing up to {limit} instances.")
|
||||||
TextColumn("[progress.description]{task.description}"),
|
for i, inst in enumerate(all_data):
|
||||||
TimeElapsedColumn(),
|
if i >= limit:
|
||||||
transient=False
|
break
|
||||||
) as progress:
|
|
||||||
task = progress.add_task("Processing instances...", total=limit)
|
|
||||||
for i, inst in enumerate(all_data):
|
|
||||||
if i >= limit:
|
|
||||||
break
|
|
||||||
pred = process_instance(inst, args.projects_dir, args.reuse_repo, args.force_venv)
|
|
||||||
predictions.append(pred)
|
|
||||||
progress.advance(task)
|
|
||||||
|
|
||||||
|
logging.info(f"=== Instance {i+1}/{limit}, ID={inst.get('instance_id')} ===")
|
||||||
|
pred = process_instance(inst, args.projects_dir, args.reuse_repo, args.force_venv)
|
||||||
|
predictions.append(pred)
|
||||||
|
|
||||||
|
# Save predictions
|
||||||
with open(predictions_file, "w", encoding="utf-8") as f:
|
with open(predictions_file, "w", encoding="utf-8") as f:
|
||||||
json.dump(predictions, f, indent=2)
|
json.dump(predictions, f, indent=2)
|
||||||
|
|
||||||
logging.info("Done generating predictions.")
|
logging.info("Done generating predictions.")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
try:
|
try:
|
||||||
main()
|
main()
|
||||||
|
|
@ -352,4 +381,4 @@ if __name__ == "__main__":
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.exception("Unhandled error occurred.")
|
logging.exception("Unhandled error occurred.")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
Loading…
Reference in New Issue