Enhance RL test inference with WandB integration and real-time output streaming
- Added unique run ID generation for WandB tracking during test inference. - Enabled WandB usage for test tracking and updated command-line arguments accordingly. - Implemented real-time output streaming for process execution, improving log visibility and debugging. - Enhanced error handling to display last few lines of stderr for better troubleshooting.
This commit is contained in:
parent
3c0d0dba49
commit
5c3105b437
1 changed files with 48 additions and 19 deletions
|
|
@ -1093,6 +1093,10 @@ async def rl_test_inference(
|
||||||
# Output file for this test run
|
# Output file for this test run
|
||||||
output_file = test_output_dir / f"test_{_current_env}_{model_safe_name}.jsonl"
|
output_file = test_output_dir / f"test_{_current_env}_{model_safe_name}.jsonl"
|
||||||
|
|
||||||
|
# Generate unique run ID for wandb
|
||||||
|
test_run_id = str(uuid.uuid4())[:8]
|
||||||
|
wandb_run_name = f"test_inference_RSIAgent_{_current_env}_{test_run_id}"
|
||||||
|
|
||||||
# Build the process command using Atropos's built-in CLI
|
# Build the process command using Atropos's built-in CLI
|
||||||
# This runs the environment's actual code with OpenRouter as the inference backend
|
# This runs the environment's actual code with OpenRouter as the inference backend
|
||||||
# We pass our locked settings + test-specific overrides via CLI args
|
# We pass our locked settings + test-specific overrides via CLI args
|
||||||
|
|
@ -1101,7 +1105,8 @@ async def rl_test_inference(
|
||||||
# Test-specific overrides
|
# Test-specific overrides
|
||||||
"--env.total_steps", str(num_steps),
|
"--env.total_steps", str(num_steps),
|
||||||
"--env.group_size", str(group_size),
|
"--env.group_size", str(group_size),
|
||||||
"--env.use_wandb", "false", # No wandb for quick tests
|
"--env.use_wandb", "true", # Enable wandb for test tracking
|
||||||
|
"--env.wandb_name", wandb_run_name,
|
||||||
"--env.data_path_to_save_groups", str(output_file),
|
"--env.data_path_to_save_groups", str(output_file),
|
||||||
# Use locked settings from our config
|
# Use locked settings from our config
|
||||||
"--env.tokenizer_name", LOCKED_FIELDS["env"]["tokenizer_name"],
|
"--env.tokenizer_name", LOCKED_FIELDS["env"]["tokenizer_name"],
|
||||||
|
|
@ -1124,12 +1129,14 @@ async def rl_test_inference(
|
||||||
cmd_display = cmd_str.replace(api_key, "***API_KEY***")
|
cmd_display = cmd_str.replace(api_key, "***API_KEY***")
|
||||||
print(f"Command: {cmd_display}")
|
print(f"Command: {cmd_display}")
|
||||||
print(f"Working dir: {TINKER_ATROPOS_ROOT}")
|
print(f"Working dir: {TINKER_ATROPOS_ROOT}")
|
||||||
|
print(f"WandB run: {wandb_run_name}")
|
||||||
print(f" {num_steps} steps × {group_size} completions = {total_rollouts_per_model} rollouts")
|
print(f" {num_steps} steps × {group_size} completions = {total_rollouts_per_model} rollouts")
|
||||||
|
|
||||||
model_results = {
|
model_results = {
|
||||||
"model": model_id,
|
"model": model_id,
|
||||||
"name": model_info["name"],
|
"name": model_info["name"],
|
||||||
"scale": model_info["scale"],
|
"scale": model_info["scale"],
|
||||||
|
"wandb_run": wandb_run_name,
|
||||||
"output_file": str(output_file),
|
"output_file": str(output_file),
|
||||||
"steps": [],
|
"steps": [],
|
||||||
"steps_tested": 0,
|
"steps_tested": 0,
|
||||||
|
|
@ -1138,7 +1145,7 @@ async def rl_test_inference(
|
||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Run the process command
|
# Run the process command with real-time output streaming
|
||||||
process = await asyncio.create_subprocess_exec(
|
process = await asyncio.create_subprocess_exec(
|
||||||
*cmd,
|
*cmd,
|
||||||
stdout=asyncio.subprocess.PIPE,
|
stdout=asyncio.subprocess.PIPE,
|
||||||
|
|
@ -1146,17 +1153,43 @@ async def rl_test_inference(
|
||||||
cwd=str(TINKER_ATROPOS_ROOT),
|
cwd=str(TINKER_ATROPOS_ROOT),
|
||||||
)
|
)
|
||||||
|
|
||||||
stdout, stderr = await asyncio.wait_for(
|
# Stream output in real-time while collecting for logs
|
||||||
process.communicate(),
|
stdout_lines = []
|
||||||
|
stderr_lines = []
|
||||||
|
log_file = test_output_dir / f"test_{_current_env}_{model_safe_name}.log"
|
||||||
|
|
||||||
|
async def read_stream(stream, lines_list, prefix=""):
|
||||||
|
"""Read stream line by line and print in real-time."""
|
||||||
|
while True:
|
||||||
|
line = await stream.readline()
|
||||||
|
if not line:
|
||||||
|
break
|
||||||
|
decoded = line.decode().rstrip()
|
||||||
|
lines_list.append(decoded)
|
||||||
|
# Print progress-related lines in real-time
|
||||||
|
if any(kw in decoded.lower() for kw in ['processing', 'group', 'step', 'progress', '%', 'completed']):
|
||||||
|
print(f" {prefix}{decoded}")
|
||||||
|
|
||||||
|
# Read both streams concurrently with timeout
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(
|
||||||
|
asyncio.gather(
|
||||||
|
read_stream(process.stdout, stdout_lines, "📊 "),
|
||||||
|
read_stream(process.stderr, stderr_lines, "⚠️ "),
|
||||||
|
),
|
||||||
timeout=600, # 10 minute timeout per model
|
timeout=600, # 10 minute timeout per model
|
||||||
)
|
)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
process.kill()
|
||||||
|
raise
|
||||||
|
|
||||||
# Decode output
|
await process.wait()
|
||||||
stdout_text = stdout.decode() if stdout else ""
|
|
||||||
stderr_text = stderr.decode() if stderr else ""
|
# Combine output for logging
|
||||||
|
stdout_text = "\n".join(stdout_lines)
|
||||||
|
stderr_text = "\n".join(stderr_lines)
|
||||||
|
|
||||||
# Write logs to files for inspection outside CLI
|
# Write logs to files for inspection outside CLI
|
||||||
log_file = test_output_dir / f"test_{_current_env}_{model_safe_name}.log"
|
|
||||||
with open(log_file, "w") as f:
|
with open(log_file, "w") as f:
|
||||||
f.write(f"Command: {cmd_display}\n")
|
f.write(f"Command: {cmd_display}\n")
|
||||||
f.write(f"Working dir: {TINKER_ATROPOS_ROOT}\n")
|
f.write(f"Working dir: {TINKER_ATROPOS_ROOT}\n")
|
||||||
|
|
@ -1170,21 +1203,17 @@ async def rl_test_inference(
|
||||||
|
|
||||||
print(f" Log file: {log_file}")
|
print(f" Log file: {log_file}")
|
||||||
|
|
||||||
# Print to console for immediate debugging
|
|
||||||
if stdout_text.strip():
|
|
||||||
print(f"\n--- STDOUT ---")
|
|
||||||
print(stdout_text[-2000:]) # Last 2000 chars
|
|
||||||
|
|
||||||
if stderr_text.strip():
|
|
||||||
print(f"\n--- STDERR ---")
|
|
||||||
print(stderr_text[-2000:]) # Last 2000 chars
|
|
||||||
|
|
||||||
if process.returncode != 0:
|
if process.returncode != 0:
|
||||||
model_results["error"] = f"Process exited with code {process.returncode}"
|
model_results["error"] = f"Process exited with code {process.returncode}"
|
||||||
model_results["stderr"] = stderr_text[-1000:]
|
model_results["stderr"] = stderr_text[-1000:]
|
||||||
model_results["stdout"] = stdout_text[-1000:]
|
model_results["stdout"] = stdout_text[-1000:]
|
||||||
model_results["log_file"] = str(log_file)
|
model_results["log_file"] = str(log_file)
|
||||||
print(f"\n ❌ Error: {model_results['error']}")
|
print(f"\n ❌ Error: {model_results['error']}")
|
||||||
|
# Print last few lines of stderr for debugging
|
||||||
|
if stderr_lines:
|
||||||
|
print(f" Last errors:")
|
||||||
|
for line in stderr_lines[-5:]:
|
||||||
|
print(f" {line}")
|
||||||
else:
|
else:
|
||||||
print(f"\n ✅ Process completed successfully")
|
print(f"\n ✅ Process completed successfully")
|
||||||
print(f" Output file: {output_file}")
|
print(f" Output file: {output_file}")
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue