Make batch_runner checkpoint incremental and atomic
This commit is contained in:
parent
669e4d0297
commit
ac6d747fa6
1 changed file with 63 additions and 16 deletions
|
|
@ -29,6 +29,7 @@ from typing import List, Dict, Any, Optional, Tuple
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from multiprocessing import Pool, Lock
|
from multiprocessing import Pool, Lock
|
||||||
import traceback
|
import traceback
|
||||||
|
import tempfile
|
||||||
|
|
||||||
from rich.progress import Progress, SpinnerColumn, BarColumn, TextColumn, TimeRemainingColumn, MofNCompleteColumn
|
from rich.progress import Progress, SpinnerColumn, BarColumn, TextColumn, TimeRemainingColumn, MofNCompleteColumn
|
||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
|
|
@ -650,13 +651,32 @@ class BatchRunner:
|
||||||
"""
|
"""
|
||||||
checkpoint_data["last_updated"] = datetime.now().isoformat()
|
checkpoint_data["last_updated"] = datetime.now().isoformat()
|
||||||
|
|
||||||
|
def _atomic_write():
|
||||||
|
"""Write checkpoint atomically (temp file + replace) to avoid corruption on crash."""
|
||||||
|
self.checkpoint_file.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
fd, tmp_path = tempfile.mkstemp(
|
||||||
|
dir=str(self.checkpoint_file.parent),
|
||||||
|
prefix='.checkpoint_',
|
||||||
|
suffix='.tmp',
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
with os.fdopen(fd, 'w', encoding='utf-8') as f:
|
||||||
|
json.dump(checkpoint_data, f, indent=2, ensure_ascii=False)
|
||||||
|
f.flush()
|
||||||
|
os.fsync(f.fileno())
|
||||||
|
os.replace(tmp_path, self.checkpoint_file)
|
||||||
|
except BaseException:
|
||||||
|
try:
|
||||||
|
os.unlink(tmp_path)
|
||||||
|
except OSError:
|
||||||
|
pass
|
||||||
|
raise
|
||||||
|
|
||||||
if lock:
|
if lock:
|
||||||
with lock:
|
with lock:
|
||||||
with open(self.checkpoint_file, 'w', encoding='utf-8') as f:
|
_atomic_write()
|
||||||
json.dump(checkpoint_data, f, indent=2, ensure_ascii=False)
|
|
||||||
else:
|
else:
|
||||||
with open(self.checkpoint_file, 'w', encoding='utf-8') as f:
|
_atomic_write()
|
||||||
json.dump(checkpoint_data, f, indent=2, ensure_ascii=False)
|
|
||||||
|
|
||||||
def _scan_completed_prompts_by_content(self) -> set:
|
def _scan_completed_prompts_by_content(self) -> set:
|
||||||
"""
|
"""
|
||||||
|
|
@ -781,13 +801,15 @@ class BatchRunner:
|
||||||
print(f" New batches created: {len(batches_to_process)}")
|
print(f" New batches created: {len(batches_to_process)}")
|
||||||
print("=" * 70 + "\n")
|
print("=" * 70 + "\n")
|
||||||
|
|
||||||
# Initialize checkpoint data (needed for saving at the end)
|
# Load existing checkpoint (so resume doesn't clobber prior progress)
|
||||||
checkpoint_data = {
|
checkpoint_data = self._load_checkpoint()
|
||||||
"run_name": self.run_name,
|
if checkpoint_data.get("run_name") != self.run_name:
|
||||||
"completed_prompts": [],
|
checkpoint_data = {
|
||||||
"batch_stats": {},
|
"run_name": self.run_name,
|
||||||
"last_updated": None
|
"completed_prompts": [],
|
||||||
}
|
"batch_stats": {},
|
||||||
|
"last_updated": None
|
||||||
|
}
|
||||||
|
|
||||||
# Prepare configuration for workers
|
# Prepare configuration for workers
|
||||||
config = {
|
config = {
|
||||||
|
|
@ -809,7 +831,7 @@ class BatchRunner:
|
||||||
}
|
}
|
||||||
|
|
||||||
# For backward compatibility, still track by index (but this is secondary to content matching)
|
# For backward compatibility, still track by index (but this is secondary to content matching)
|
||||||
completed_prompts_set = set()
|
completed_prompts_set = set(checkpoint_data.get("completed_prompts", []))
|
||||||
|
|
||||||
# Aggregate statistics across all batches
|
# Aggregate statistics across all batches
|
||||||
total_tool_stats = {}
|
total_tool_stats = {}
|
||||||
|
|
@ -818,6 +840,9 @@ class BatchRunner:
|
||||||
|
|
||||||
print(f"\n🔧 Initializing {self.num_workers} worker processes...")
|
print(f"\n🔧 Initializing {self.num_workers} worker processes...")
|
||||||
|
|
||||||
|
# Checkpoint writes happen in the parent process; keep a lock for safety.
|
||||||
|
checkpoint_lock = Lock()
|
||||||
|
|
||||||
# Process batches in parallel
|
# Process batches in parallel
|
||||||
with Pool(processes=self.num_workers) as pool:
|
with Pool(processes=self.num_workers) as pool:
|
||||||
# Create tasks for each batch
|
# Create tasks for each batch
|
||||||
|
|
@ -863,6 +888,25 @@ class BatchRunner:
|
||||||
for result in pool.imap_unordered(_process_batch_worker, tasks):
|
for result in pool.imap_unordered(_process_batch_worker, tasks):
|
||||||
results.append(result)
|
results.append(result)
|
||||||
progress.update(task, advance=1)
|
progress.update(task, advance=1)
|
||||||
|
|
||||||
|
# Incremental checkpoint update (so resume works after crash)
|
||||||
|
try:
|
||||||
|
batch_num = result.get('batch_num')
|
||||||
|
completed = result.get('completed_prompts', []) or []
|
||||||
|
completed_prompts_set.update(completed)
|
||||||
|
|
||||||
|
if isinstance(batch_num, int):
|
||||||
|
checkpoint_data.setdefault('batch_stats', {})[str(batch_num)] = {
|
||||||
|
'processed': result.get('processed', 0),
|
||||||
|
'skipped': result.get('skipped', 0),
|
||||||
|
'discarded_no_reasoning': result.get('discarded_no_reasoning', 0),
|
||||||
|
}
|
||||||
|
|
||||||
|
checkpoint_data['completed_prompts'] = sorted(completed_prompts_set)
|
||||||
|
self._save_checkpoint(checkpoint_data, lock=checkpoint_lock)
|
||||||
|
except Exception as ckpt_err:
|
||||||
|
# Don't fail the run if checkpoint write fails
|
||||||
|
print(f"⚠️ Warning: Failed to save incremental checkpoint: {ckpt_err}")
|
||||||
finally:
|
finally:
|
||||||
root_logger.setLevel(original_level)
|
root_logger.setLevel(original_level)
|
||||||
|
|
||||||
|
|
@ -891,9 +935,12 @@ class BatchRunner:
|
||||||
for key in total_reasoning_stats:
|
for key in total_reasoning_stats:
|
||||||
total_reasoning_stats[key] += batch_result.get("reasoning_stats", {}).get(key, 0)
|
total_reasoning_stats[key] += batch_result.get("reasoning_stats", {}).get(key, 0)
|
||||||
|
|
||||||
# Save final checkpoint
|
# Save final checkpoint (best-effort; incremental writes already happened)
|
||||||
checkpoint_data["completed_prompts"] = all_completed_prompts
|
try:
|
||||||
self._save_checkpoint(checkpoint_data)
|
checkpoint_data["completed_prompts"] = all_completed_prompts
|
||||||
|
self._save_checkpoint(checkpoint_data, lock=checkpoint_lock)
|
||||||
|
except Exception as ckpt_err:
|
||||||
|
print(f"⚠️ Warning: Failed to save final checkpoint: {ckpt_err}")
|
||||||
|
|
||||||
# Calculate success rates
|
# Calculate success rates
|
||||||
for tool_name in total_tool_stats:
|
for tool_name in total_tool_stats:
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue