Enhance batch processing with progress tracking and update AIAgent for OpenRouter detection
- Integrated tqdm for progress tracking in batch processing, replacing map with imap_unordered for improved performance. - Added base_url attribute in AIAgent to facilitate OpenRouter detection.
This commit is contained in:
parent
b66c093316
commit
6e3dbb8d8b
2 changed files with 10 additions and 2 deletions
|
|
@ -30,6 +30,7 @@ from datetime import datetime
|
||||||
from multiprocessing import Pool, Manager, Lock
|
from multiprocessing import Pool, Manager, Lock
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
|
from tqdm import tqdm
|
||||||
import fire
|
import fire
|
||||||
|
|
||||||
from run_agent import AIAgent
|
from run_agent import AIAgent
|
||||||
|
|
@ -642,8 +643,14 @@ class BatchRunner:
|
||||||
print(f"✅ Created {len(tasks)} batch tasks")
|
print(f"✅ Created {len(tasks)} batch tasks")
|
||||||
print(f"🚀 Starting parallel batch processing...\n")
|
print(f"🚀 Starting parallel batch processing...\n")
|
||||||
|
|
||||||
# Use map to process batches in parallel
|
# Use imap_unordered with tqdm for progress tracking
|
||||||
results = pool.map(_process_batch_worker, tasks)
|
results = list(tqdm(
|
||||||
|
pool.imap_unordered(_process_batch_worker, tasks),
|
||||||
|
total=len(tasks),
|
||||||
|
desc="📦 Batches",
|
||||||
|
unit="batch",
|
||||||
|
ncols=80
|
||||||
|
))
|
||||||
|
|
||||||
# Aggregate all batch statistics and update checkpoint
|
# Aggregate all batch statistics and update checkpoint
|
||||||
all_completed_prompts = list(completed_prompts_set)
|
all_completed_prompts = list(completed_prompts_set)
|
||||||
|
|
|
||||||
|
|
@ -102,6 +102,7 @@ class AIAgent:
|
||||||
self.ephemeral_system_prompt = ephemeral_system_prompt
|
self.ephemeral_system_prompt = ephemeral_system_prompt
|
||||||
self.log_prefix_chars = log_prefix_chars
|
self.log_prefix_chars = log_prefix_chars
|
||||||
self.log_prefix = f"{log_prefix} " if log_prefix else ""
|
self.log_prefix = f"{log_prefix} " if log_prefix else ""
|
||||||
|
self.base_url = base_url or "" # Store for OpenRouter detection
|
||||||
|
|
||||||
# Store OpenRouter provider preferences
|
# Store OpenRouter provider preferences
|
||||||
self.providers_allowed = providers_allowed
|
self.providers_allowed = providers_allowed
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue