SWEBench updates.

This commit is contained in:
AI Christianson 2024-12-30 13:25:51 -05:00
parent 34acb462f9
commit 8b7eb76be6
1 changed file with 146 additions and 313 deletions

View File

@ -1,13 +1,12 @@
#!/usr/bin/env python3
"""
Script to generate predictions for SWE-bench Lite (princeton-nlp/SWE-bench_Lite).

This script:
- Loads the SWE-bench Lite dataset
- Clones each repo at the specified commit
- Forms a prompt from the instance fields (problem_statement, FAIL_TO_PASS, PASS_TO_PASS)
- Calls ra-aid to create a patch
- Writes out predictions in the required JSON format
"""
import argparse
@ -19,28 +18,22 @@ import sys
import tempfile
from datetime import datetime
from pathlib import Path
from typing import Optional, Tuple, Dict, Any
from typing import Optional, Tuple, Dict, Any, List
from datasets import load_dataset
from git import Repo
from rich.logging import RichHandler
from rich.progress import Progress, SpinnerColumn, TextColumn, TimeElapsedColumn
def setup_logging(log_dir: Path, verbose: bool = False) -> None:
    """Configure logging with both file and console handlers.

    Args:
        log_dir: Directory in which the log file is created.
        verbose: When True, the console handler emits DEBUG records;
            otherwise INFO. The file handler always records DEBUG.
    """
    log_dir.mkdir(parents=True, exist_ok=True)
    log_file = log_dir / "generate_dataset.log"

    root_logger = logging.getLogger()
    root_logger.setLevel(logging.DEBUG if verbose else logging.INFO)

    # File handler captures everything at DEBUG regardless of console level.
    file_handler = logging.FileHandler(log_file)
    file_handler.setLevel(logging.DEBUG)
    # NOTE(review): the Formatter format string fell on a diff-hunk boundary
    # in the rendered source — this standard format is a reconstruction;
    # confirm against version control.
    file_formatter = logging.Formatter(
        "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    )
    file_handler.setFormatter(file_formatter)
    root_logger.addHandler(file_handler)

    # Console handler uses rich for readable tracebacks.
    console_handler = RichHandler(
        rich_tracebacks=True,
        show_time=False,
    )
    console_handler.setLevel(logging.DEBUG if verbose else logging.INFO)
    root_logger.addHandler(console_handler)
def load_dataset_safely() -> Optional[Any]:
    """Load the SWE-bench Lite dataset with error handling.

    Returns:
        The loaded dataset object on success, or None when loading fails
        (the failure is logged rather than raised).
    """
    try:
        return load_dataset("princeton-nlp/SWE-bench_Lite")
    except Exception as e:
        logging.error(f"Failed to load dataset: {e}")
        return None
def create_output_dirs() -> Tuple[Path, Path]:
    """Create the dated base/log output directory structure.

    Returns:
        Tuple of (base_dir, log_dir) where base_dir is
        ``evaluation/default/<YYYYMMDD>_raaid`` and log_dir is its
        ``logs`` subdirectory. Both exist on return.
    """
    date_str = datetime.now().strftime("%Y%m%d")
    base_dir = Path("evaluation") / "default" / f"{date_str}_raaid"
    log_dir = base_dir / "logs"
    # Creating the leaf with parents=True also creates base_dir, so a
    # separate base_dir.mkdir() call is unnecessary.
    log_dir.mkdir(parents=True, exist_ok=True)
    return base_dir, log_dir
# NOTE(review): this span is diff-render residue — the pre-commit
# process_dataset_instance (deleted by this commit) is interleaved with a
# few replacement lines, and all indentation was stripped by the HTML
# rendering. It is not valid Python as shown; recover the real text from
# version control before editing. Comments below annotate the visible logic.
def process_dataset_instance(instance: Dict[str, Any], output_dir: Path) -> bool:
"""Process a single dataset instance.
Args:
instance: Dataset instance containing problem information
output_dir: Directory to store output files
Returns:
bool: True if processing was successful, False otherwise
"""
try:
# Required fields
logging.debug(f"Instance data: {instance}")
logging.debug(f"Instance keys: {instance.keys()}")
instance_id = str(instance['id'])  # Use id as unique identifier
repo_url = instance['repo_url']
commit_id = instance['code_before']['revision']
# Issue description
issue_title = instance['issue_title']
issue_body = instance.get('issue_body', '')  # Optional with default
issue_desc = f"{issue_title}\n\n{issue_body}"
with tempfile.TemporaryDirectory() as temp_dir:
temp_path = Path(temp_dir)
# Clone repository at the instance's pinned revision.
repo = Repo.clone_from(repo_url, temp_path)
repo.git.checkout(commit_id)
# Format input for ra-aid
# NOTE(review): the two assignments below are replacement lines from
# the diff; they shadow issue_desc computed above.
issue_desc = instance['problem_statement']
test_info = instance.get('test_information', '')
# Run ra-aid
# NOTE(review): this call uses the OLD run_raaid signature
# (repo_dir, issue_desc, test_info), which this commit replaces.
patch = run_raaid(temp_path, issue_desc, test_info)
if not patch:
return False
# Write prediction
write_prediction(output_dir, instance_id, patch)
return True
except Exception as e:
# Use instance.get() to avoid KeyError in error logging
instance_id = instance.get('id', '<unknown>')
logging.error(f"Failed to process instance {instance_id}: {e}")
return False
def parse_test_information(test_info: str) -> Tuple[List[str], List[str]]:
    """Parse raw test information into failing and passing test lists.

    Sections are separated by blank lines; a section beginning with
    'FAIL_TO_PASS:' or 'PASS_TO_PASS:' lists one test name per line.

    Args:
        test_info: Raw test information string.

    Returns:
        Tuple of (fail_to_pass, pass_to_pass) test-name lists.

    Raises:
        ValueError: If no FAIL_TO_PASS tests are found.
    """
    fail_to_pass: List[str] = []
    pass_to_pass: List[str] = []

    for section in test_info.split('\n\n'):
        section = section.strip()
        if not section:
            continue
        # Strip the marker as a prefix only; the original str.replace()
        # would also delete the marker if it appeared inside a test name.
        if section.startswith('FAIL_TO_PASS:'):
            body = section[len('FAIL_TO_PASS:'):]
            fail_to_pass.extend(t.strip() for t in body.strip().split('\n') if t.strip())
        elif section.startswith('PASS_TO_PASS:'):
            body = section[len('PASS_TO_PASS:'):]
            pass_to_pass.extend(t.strip() for t in body.strip().split('\n') if t.strip())

    if not fail_to_pass:
        raise ValueError("No FAIL_TO_PASS tests found in test information")
    return fail_to_pass, pass_to_pass
def run_raaid(
    repo_dir: Path,
    problem_statement: str,
    fail_tests: List[str],
    pass_tests: List[str]
) -> Optional[str]:
    """Run ra-aid on the problem statement, returning a generated patch if possible.

    Args:
        repo_dir: Checked-out repository the agent works in (and that is
            diffed afterwards).
        problem_statement: Issue text used as the core of the prompt.
        fail_tests: Tests that must go from failing to passing.
        pass_tests: Tests that must remain passing.

    Returns:
        The git patch produced by the agent's edits, or None on any
        failure (non-zero exit, timeout, subprocess error, or no diff).
    """
    # Build the prompt: problem statement plus the two test lists.
    prompt = f"{problem_statement}\n\nTests that need to be fixed:\n```\n"
    for t in fail_tests:
        prompt += f"- {t}\n"
    prompt += "```\n\n"
    if pass_tests:
        prompt += "Tests that must remain passing:\n```\n"
        for t in pass_tests:
            prompt += f"- {t}\n"
        prompt += "```\n\n"

    # Single implementation phase; --cowboy-mode skips interactive prompts.
    impl_cmd = [
        'ra-aid',
        '--cowboy-mode',
        '-m', prompt,
    ]
    try:
        impl_result = subprocess.run(
            impl_cmd,
            cwd=repo_dir,
            capture_output=True,
            text=True,
            timeout=300
        )
        if impl_result.returncode != 0:
            logging.error("ra-aid returned non-zero exit code.")
            logging.debug(f"stdout: {impl_result.stdout}")
            logging.debug(f"stderr: {impl_result.stderr}")
            return None
    except subprocess.TimeoutExpired:
        logging.error("ra-aid implementation phase timed out.")
        return None
    except Exception as e:
        logging.error(f"ra-aid error: {e}")
        return None

    # Collect whatever the agent changed as a patch.
    repo = Repo(repo_dir)
    patch = get_git_patch(repo)
    return patch
def get_git_patch(repo: Repo) -> Optional[str]:
    """Generate a git patch for current (uncommitted) changes.

    Args:
        repo: GitPython Repo object for the working tree to diff.

    Returns:
        The unified diff text, or None when the tree is clean, the diff is
        empty, contains no added lines, or diffing fails.
    """
    if not repo.is_dirty():
        logging.info("No repo changes detected.")
        return None
    try:
        patch = repo.git.diff(unified=3)
        # Basic validation: non-empty and contains at least one added line.
        if not patch or not patch.strip():
            return None
        # The '+++ b/<file>' diff header also starts with '+', so the bare
        # startswith('+') check always passed; exclude headers explicitly.
        if not any(
            line.startswith('+') and not line.startswith('+++')
            for line in patch.splitlines()
        ):
            return None
        return patch
    except Exception as e:
        logging.error(f"Failed to generate patch: {e}")
        return None
def process_instance(instance: Dict[str, Any], output_repo_dir: Path) -> Dict[str, Any]:
    """
    Process a single dataset instance:
    - Clone the repo
    - Checkout commit
    - Build prompt from problem_statement, FAIL_TO_PASS, PASS_TO_PASS
    - Return dict in required format:
        {
            "instance_id": ...,
            "model_patch": ...,
            "model_name_or_path": ...
        }

    A clone/checkout or agent failure yields an empty "model_patch" rather
    than raising, so one bad instance cannot abort the whole run.
    """
    inst_id = instance.get("instance_id", "<unknown>")
    repo_name = instance["repo"]
    commit = instance["base_commit"]
    problem_statement = instance["problem_statement"]
    fail_tests = instance.get("FAIL_TO_PASS", [])
    pass_tests = instance.get("PASS_TO_PASS", [])

    # Normalize to lists if the dataset stores them as plain strings.
    if isinstance(fail_tests, str):
        fail_tests = [fail_tests]
    if isinstance(pass_tests, str):
        pass_tests = [pass_tests]

    # 'repo' is usually "org/repo"; build a GitHub URL unless one is given.
    if "github.com" not in repo_name:
        repo_url = f"https://github.com/{repo_name}.git"
    else:
        repo_url = repo_name

    patch_str = None
    with tempfile.TemporaryDirectory() as tmp:
        tmp_path = Path(tmp)
        try:
            # Clone & checkout the pinned base commit.
            repo = Repo.clone_from(repo_url, tmp_path)
            repo.git.checkout(commit)
        except Exception as e:
            logging.error(f"Failed to clone/check out {repo_url}:{commit} - {e}")
            return {
                "instance_id": inst_id,
                "model_patch": "",
                "model_name_or_path": "ra-aid"
            }
        # Run ra-aid against the checked-out tree.
        patch_str = run_raaid(tmp_path, problem_statement, fail_tests, pass_tests)

    # Return required prediction structure.
    return {
        "instance_id": inst_id,
        "model_patch": patch_str if patch_str else "",
        "model_name_or_path": "ra-aid"
    }
def cleanup_temp_files(temp_dir: Path) -> None:
    """Delete *temp_dir* and everything under it, if it exists.

    Args:
        temp_dir: Directory containing temporary files; silently ignored
            when it does not exist.
    """
    if not temp_dir.exists():
        return
    shutil.rmtree(temp_dir)
    logging.debug(f"Cleaned up temporary directory: {temp_dir}")
def main() -> None:
    """CLI entry point: generate ra-aid predictions for SWE-bench Lite.

    Parses arguments, loads the dataset, runs each instance through
    process_instance, and writes all predictions to
    ``<output_dir>/predictions.json``.
    """
    parser = argparse.ArgumentParser(
        description="Generate predictions for SWE-bench Lite using ra-aid."
    )
    parser.add_argument(
        "output_dir",
        type=Path,
        help="Directory to store prediction file"
    )
    parser.add_argument(
        "--verbose",
        action="store_true",
        help="Enable verbose logging"
    )
    # NOTE(review): the middle of this add_argument fell on a diff-hunk
    # boundary in the rendered source; type=int is assumed — confirm.
    parser.add_argument(
        "--num-instances",
        type=int,
        default=None,
        help="Number of instances to process (default: all)"
    )
    args = parser.parse_args()

    base_dir, log_dir = create_output_dirs()
    setup_logging(log_dir, args.verbose)
    logging.info("Starting script")

    dataset = load_dataset_safely()
    if dataset is None:
        sys.exit(1)

    # Combine "dev" and "test" splits (no "train" in this dataset).
    all_data = list(dataset["dev"]) + list(dataset["test"])
    args.output_dir.mkdir(parents=True, exist_ok=True)
    predictions_file = args.output_dir / "predictions.json"
    predictions = []
    limit = args.num_instances if args.num_instances else len(all_data)

    with Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        TimeElapsedColumn(),
        transient=False
    ) as progress:
        task = progress.add_task("Processing instances...", total=limit)
        for i, inst in enumerate(all_data):
            if i >= limit:
                break
            try:
                pred = process_instance(inst, args.output_dir)
                predictions.append(pred)
            except Exception as e:
                logging.error(f"Error processing instance: {inst.get('instance_id', '')} - {e}")
            finally:
                progress.advance(task)

    with open(predictions_file, "w", encoding="utf-8") as f:
        json.dump(predictions, f, indent=2)
    logging.info("Done generating predictions.")
if __name__ == "__main__":
    try:
        main()
    except KeyboardInterrupt:
        # Graceful exit on Ctrl-C; non-zero status signals interruption.
        print("\nOperation cancelled by user.")
        sys.exit(1)
    except Exception:
        # The bound name was unused; logging.exception records the traceback.
        logging.exception("Unhandled error occurred.")
        sys.exit(1)