The architecture has been updated
This commit is contained in:
parent
805f7a017e
commit
a01257ead9
1119 changed files with 226 additions and 352 deletions
0
hermes_code/tests/integration/__init__.py
Normal file
0
hermes_code/tests/integration/__init__.py
Normal file
132
hermes_code/tests/integration/test_batch_runner.py
Normal file
132
hermes_code/tests/integration/test_batch_runner.py
Normal file
|
|
@ -0,0 +1,132 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test script for batch runner
|
||||
|
||||
This script tests the batch runner with a small sample dataset
|
||||
to verify functionality before running large batches.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
import json
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def create_test_dataset():
    """Create a small JSONL test dataset under tests/ and return its path.

    Returns:
        Path: location of the written ``tests/test_dataset.jsonl`` file.
    """
    test_file = Path("tests/test_dataset.jsonl")
    test_file.parent.mkdir(exist_ok=True)

    prompts = [
        {"prompt": "What is 2 + 2?"},
        {"prompt": "What is the capital of France?"},
        {"prompt": "Explain what Python is in one sentence."},
    ]

    # Explicit UTF-8: ensure_ascii=False may emit non-ASCII bytes, and the
    # platform default encoding is not UTF-8 everywhere. Also matches the
    # checkpoint-resumption test's dataset writer.
    with open(test_file, 'w', encoding='utf-8') as f:
        for prompt in prompts:
            f.write(json.dumps(prompt, ensure_ascii=False) + "\n")

    print(f"✅ Created test dataset: {test_file}")
    return test_file
|
||||
|
||||
|
||||
def cleanup_test_run(run_name):
    """Delete the output directory for *run_name* under data/, if present."""
    target = Path("data") / run_name
    if not target.exists():
        return
    shutil.rmtree(target)
    print(f"🗑️ Cleaned up test output: {target}")
|
||||
|
||||
|
||||
def verify_output(run_name):
    """Verify that a batch run produced the expected output files.

    Checks for the run's output directory, checkpoint.json, statistics.json
    and at least one batch_*.jsonl file, then prints a summary of the
    statistics file.

    Args:
        run_name: Run name; output is expected under ``data/<run_name>``.

    Returns:
        bool: True if all expected artifacts exist, False otherwise.
    """
    output_dir = Path("data") / run_name

    # Check directory exists
    if not output_dir.exists():
        print(f"❌ Output directory not found: {output_dir}")
        return False

    # Check for checkpoint
    checkpoint_file = output_dir / "checkpoint.json"
    if not checkpoint_file.exists():
        print(f"❌ Checkpoint file not found: {checkpoint_file}")
        return False

    # Check for statistics
    stats_file = output_dir / "statistics.json"
    if not stats_file.exists():
        print(f"❌ Statistics file not found: {stats_file}")
        return False

    # Check for batch files
    batch_files = list(output_dir.glob("batch_*.jsonl"))
    if not batch_files:
        print(f"❌ No batch files found in: {output_dir}")
        return False

    # Fix: dropped f-prefixes from constant strings (lint F541).
    print("✅ Output verification passed:")
    print(f" - Checkpoint: {checkpoint_file}")
    print(f" - Statistics: {stats_file}")
    print(f" - Batch files: {len(batch_files)}")

    # Load and display statistics; explicit UTF-8 for cross-platform reads.
    with open(stats_file, encoding='utf-8') as f:
        stats = json.load(f)

    print("\n📊 Statistics Summary:")
    print(f" - Total prompts: {stats['total_prompts']}")
    print(f" - Total batches: {stats['total_batches']}")
    print(f" - Duration: {stats['duration_seconds']}s")

    if stats.get('tool_statistics'):
        print(" - Tool calls:")
        for tool, tool_stats in stats['tool_statistics'].items():
            # success_rate is assumed to already be a percentage — TODO confirm
            print(f" • {tool}: {tool_stats['count']} calls, {tool_stats['success_rate']:.1f}% success")

    return True
|
||||
|
||||
|
||||
def main():
    """Run the test setup: clean any previous output, create a fresh test
    dataset, and print instructions for running the batch runner manually.

    Note: the batch runner itself is intentionally NOT executed here to
    avoid API calls during testing; users run it manually with their own
    API keys configured.
    """
    print("🧪 Batch Runner Test")
    print("=" * 60)

    run_name = "test_run"

    # Clean up any previous test run
    cleanup_test_run(run_name)

    # Create test dataset
    test_file = create_test_dataset()

    # Fix: dropped f-prefixes from constant strings (lint F541).
    print("\n📝 To run the test manually:")
    print(" python batch_runner.py \\")
    print(f" --dataset_file={test_file} \\")
    print(" --batch_size=2 \\")
    print(f" --run_name={run_name} \\")
    print(" --distribution=minimal \\")
    print(" --num_workers=2")

    print("\n💡 Or test with different distributions:")
    print(" python batch_runner.py --list_distributions")

    print("\n🔍 After running, you can verify output with:")
    print(" python tests/test_batch_runner.py --verify")

    # Note: We don't actually run the batch runner here to avoid API calls during testing
    # Users should run it manually with their API keys configured
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # CLI entry point: `--verify` checks a previously created run's output
    # under data/test_run; any other invocation performs the setup in main().
    import sys

    if "--verify" in sys.argv:
        run_name = "test_run"
        verify_output(run_name)
    else:
        main()
|
||||
|
||||
440
hermes_code/tests/integration/test_checkpoint_resumption.py
Normal file
440
hermes_code/tests/integration/test_checkpoint_resumption.py
Normal file
|
|
@ -0,0 +1,440 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test script to verify checkpoint behavior in batch_runner.py
|
||||
|
||||
This script simulates batch processing with intentional failures to test:
|
||||
1. Whether checkpoints are saved incrementally during processing
|
||||
2. Whether resume functionality works correctly after interruption
|
||||
3. Whether data integrity is maintained across checkpoint cycles
|
||||
|
||||
Usage:
|
||||
# Test current implementation
|
||||
python tests/test_checkpoint_resumption.py --test_current
|
||||
|
||||
# Test after fix is applied
|
||||
python tests/test_checkpoint_resumption.py --test_fixed
|
||||
|
||||
# Run full comparison
|
||||
python tests/test_checkpoint_resumption.py --compare
|
||||
"""
|
||||
|
||||
import pytest
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Any
|
||||
import traceback
|
||||
|
||||
# Add project root to path to import batch_runner
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
|
||||
|
||||
|
||||
def create_test_dataset(num_prompts: int = 20) -> Path:
    """Write a small JSONL dataset for checkpoint testing and return its path."""
    data_dir = Path("tests/test_data")
    data_dir.mkdir(parents=True, exist_ok=True)

    dataset_file = data_dir / "checkpoint_test_dataset.jsonl"

    # Build one record per prompt index; keys in the same order as before.
    records = (
        {"prompt": f"Test prompt {i}: What is 2+2? Just answer briefly.", "test_id": i}
        for i in range(num_prompts)
    )
    with open(dataset_file, 'w', encoding='utf-8') as f:
        f.writelines(json.dumps(record, ensure_ascii=False) + "\n" for record in records)

    print(f"✅ Created test dataset: {dataset_file} ({num_prompts} prompts)")
    return dataset_file
|
||||
|
||||
|
||||
def monitor_checkpoint_during_run(checkpoint_file: Path, duration: int = 30) -> List[Dict[str, Any]]:
    """
    Monitor checkpoint file during a batch run to see when it gets updated.

    Polls the file's mtime every 0.5s; whenever it changes, the checkpoint
    JSON is read and a snapshot recorded.

    Args:
        checkpoint_file: Path to checkpoint file to monitor
        duration: How long to monitor (seconds)

    Returns:
        List of checkpoint snapshots with timestamps
    """
    snapshots = []
    start_time = time.time()
    last_mtime = None

    print(f"\n🔍 Monitoring checkpoint file: {checkpoint_file}")
    print(f" Duration: {duration}s")
    print("-" * 70)

    while time.time() - start_time < duration:
        if checkpoint_file.exists():
            current_mtime = checkpoint_file.stat().st_mtime

            # Check if file was modified
            if last_mtime is None or current_mtime != last_mtime:
                elapsed = time.time() - start_time

                try:
                    with open(checkpoint_file, 'r') as f:
                        checkpoint_data = json.load(f)

                    snapshot = {
                        "elapsed_seconds": round(elapsed, 2),
                        "completed_count": len(checkpoint_data.get("completed_prompts", [])),
                        "completed_prompts": checkpoint_data.get("completed_prompts", [])[:5],  # First 5 for display
                        "timestamp": checkpoint_data.get("last_updated")
                    }

                    snapshots.append(snapshot)

                    print(f"[{elapsed:6.2f}s] Checkpoint updated: {snapshot['completed_count']} prompts completed")

                    # Bug fix: only remember the mtime after a SUCCESSFUL read.
                    # Previously last_mtime was updated even when the read failed
                    # (e.g. JSON decode error on a file caught mid-write), which
                    # silently skipped that checkpoint update forever.
                    last_mtime = current_mtime

                except Exception as e:
                    print(f"[{elapsed:6.2f}s] Error reading checkpoint: {e}")
        else:
            if len(snapshots) == 0:
                print(f"[{time.time() - start_time:6.2f}s] Checkpoint file not yet created...")

        time.sleep(0.5)  # Check every 0.5 seconds

    return snapshots
|
||||
|
||||
|
||||
def _cleanup_test_artifacts(*paths):
|
||||
"""Remove test-generated files and directories."""
|
||||
for p in paths:
|
||||
p = Path(p)
|
||||
if p.is_dir():
|
||||
shutil.rmtree(p, ignore_errors=True)
|
||||
elif p.is_file():
|
||||
p.unlink(missing_ok=True)
|
||||
|
||||
|
||||
def test_current_implementation():
    """Test the current checkpoint implementation.

    Runs a small real batch (12 prompts, batch_size=3, 2 workers) while a
    daemon thread polls the checkpoint file, then judges from the number of
    observed checkpoint-file updates whether checkpointing is incremental.

    Returns:
        bool: True if multiple incremental checkpoint updates were observed,
        False otherwise (including on any exception during the run).
    """
    print("\n" + "=" * 70)
    print("TEST 1: Current Implementation - Checkpoint Timing")
    print("=" * 70)
    print("\n📝 Testing whether checkpoints are saved incrementally during run...")

    # Setup
    dataset_file = create_test_dataset(num_prompts=12)
    run_name = "checkpoint_test_current"
    output_dir = Path("data") / run_name

    # Clean up any existing test data
    if output_dir.exists():
        shutil.rmtree(output_dir)

    # Import here to avoid issues if module changes
    from batch_runner import BatchRunner

    checkpoint_file = output_dir / "checkpoint.json"

    # Start monitoring in a separate process would be ideal, but for simplicity
    # we'll just check before and after
    print(f"\n▶️ Starting batch run...")
    print(f" Dataset: {dataset_file}")
    print(f" Batch size: 3 (4 batches total)")
    print(f" Workers: 2")
    print(f" Expected behavior: If incremental, checkpoint should update during run")

    start_time = time.time()

    try:
        runner = BatchRunner(
            dataset_file=str(dataset_file),
            batch_size=3,
            run_name=run_name,
            distribution="default",
            max_iterations=3,  # Keep it short
            model="claude-opus-4-20250514",
            num_workers=2,
            verbose=False
        )

        # Run with monitoring
        import threading
        snapshots = []

        def monitor():
            # Rebinds the outer `snapshots` with the full snapshot list once
            # monitoring finishes.
            nonlocal snapshots
            snapshots = monitor_checkpoint_during_run(checkpoint_file, duration=60)

        # Daemon thread: won't block interpreter exit if the run hangs.
        monitor_thread = threading.Thread(target=monitor, daemon=True)
        monitor_thread.start()

        runner.run(resume=False)

        # NOTE(review): monitor polls for up to 60s; if the run finishes
        # sooner, this join waits only 2s more and the daemon thread is
        # abandoned — snapshots may then miss the final update.
        monitor_thread.join(timeout=2)

    except Exception as e:
        print(f"❌ Error during run: {e}")
        traceback.print_exc()
        return False
    finally:
        # Always remove the generated dataset and output dir, pass or fail.
        _cleanup_test_artifacts(dataset_file, output_dir)

    elapsed = time.time() - start_time

    # Analyze results
    print("\n" + "=" * 70)
    print("📊 TEST RESULTS")
    print("=" * 70)
    print(f"Total run time: {elapsed:.2f}s")
    print(f"Checkpoint updates observed: {len(snapshots)}")

    if len(snapshots) == 0:
        print("\n❌ ISSUE: No checkpoint updates observed during run")
        print(" This suggests checkpoints are only saved at the end")
        return False
    elif len(snapshots) == 1:
        print("\n⚠️ WARNING: Only 1 checkpoint update (likely at the end)")
        print(" This confirms the bug - no incremental checkpointing")
        return False
    else:
        print(f"\n✅ GOOD: Multiple checkpoint updates ({len(snapshots)}) observed")
        print(" Checkpointing appears to be incremental")

    # Show timeline
    print("\n📈 Checkpoint Timeline:")
    for i, snapshot in enumerate(snapshots, 1):
        print(f" {i}. [{snapshot['elapsed_seconds']:6.2f}s] "
              f"{snapshot['completed_count']} prompts completed")

    return True
|
||||
|
||||
|
||||
def test_interruption_and_resume():
    """Test that resume actually works after interruption.

    Simulates an interruption by first running only the first 5 prompts of a
    15-prompt dataset, then re-running with resume=True over the full dataset
    and checking the checkpoint ends with all 15 prompts completed.

    Returns:
        bool: True if the resume run completed all 15 prompts, False otherwise
        (including on any exception).
    """
    print("\n" + "=" * 70)
    print("TEST 2: Interruption and Resume")
    print("=" * 70)
    print("\n📝 Testing whether resume works after manual interruption...")

    # Setup
    dataset_file = create_test_dataset(num_prompts=15)
    run_name = "checkpoint_test_resume"
    output_dir = Path("data") / run_name

    # Clean up any existing test data
    if output_dir.exists():
        shutil.rmtree(output_dir)

    from batch_runner import BatchRunner

    checkpoint_file = output_dir / "checkpoint.json"

    print(f"\n▶️ Starting first run (will process 5 prompts, then simulate interruption)...")

    temp_dataset = Path("tests/test_data/checkpoint_test_resume_partial.jsonl")
    try:
        # Create a modified dataset with only first 5 prompts for initial run
        with open(dataset_file, 'r') as f:
            lines = f.readlines()[:5]
        with open(temp_dataset, 'w') as f:
            f.writelines(lines)

        runner = BatchRunner(
            dataset_file=str(temp_dataset),
            batch_size=2,
            run_name=run_name,
            distribution="default",
            max_iterations=3,
            model="claude-opus-4-20250514",
            num_workers=1,
            verbose=False
        )

        runner.run(resume=False)

        # Check checkpoint after first run
        if not checkpoint_file.exists():
            print("❌ ERROR: Checkpoint file not created after first run")
            return False

        with open(checkpoint_file, 'r') as f:
            checkpoint_data = json.load(f)

        initial_completed = len(checkpoint_data.get("completed_prompts", []))
        print(f"✅ First run completed: {initial_completed} prompts saved to checkpoint")

        # Now try to resume with full dataset
        print(f"\n▶️ Starting resume run with full dataset (15 prompts)...")

        # Same run_name so the second runner picks up the first's checkpoint.
        runner2 = BatchRunner(
            dataset_file=str(dataset_file),
            batch_size=2,
            run_name=run_name,
            distribution="default",
            max_iterations=3,
            model="claude-opus-4-20250514",
            num_workers=1,
            verbose=False
        )

        runner2.run(resume=True)

        # Check final checkpoint
        with open(checkpoint_file, 'r') as f:
            final_checkpoint = json.load(f)

        final_completed = len(final_checkpoint.get("completed_prompts", []))

        print("\n" + "=" * 70)
        print("📊 TEST RESULTS")
        print("=" * 70)
        print(f"Initial completed: {initial_completed}")
        print(f"Final completed: {final_completed}")
        print(f"Expected: 15")

        if final_completed == 15:
            print("\n✅ PASS: Resume successfully completed all prompts")
            return True
        else:
            print(f"\n❌ FAIL: Expected 15 completed, got {final_completed}")
            return False

    except Exception as e:
        print(f"❌ Error during test: {e}")
        traceback.print_exc()
        return False
    finally:
        # Always remove the datasets and output dir, pass or fail.
        _cleanup_test_artifacts(dataset_file, temp_dataset, output_dir)
|
||||
|
||||
|
||||
def test_simulated_crash():
    """Test behavior when process crashes mid-execution."""
    banner = "=" * 70
    print("\n" + banner)
    print("TEST 3: Simulated Crash During Execution")
    print(banner)
    print("\n📝 This test would require running in a subprocess and killing it...")
    print(" Skipping for safety - manual testing recommended")
    # None signals "skipped" to the summary printer in main().
    return None
|
||||
|
||||
|
||||
def print_test_plan():
    """Print the detailed test and fix plan."""
    print("\n" + "=" * 70)
    print("CHECKPOINT FIX - DETAILED PLAN")
    print("=" * 70)

    # The whole plan is one literal; the "line 558-559" citation refers to
    # batch_runner.py at the time this plan was written — TODO confirm still accurate.
    print("""
📋 PROBLEM SUMMARY
------------------
Current implementation uses pool.map() which blocks until ALL batches complete.
Checkpoint is only saved after all batches finish (line 558-559).

If process crashes during batch processing:
- All progress is lost
- Resume does nothing (no incremental checkpoint was saved)

📋 PROPOSED SOLUTION
--------------------
Replace pool.map() with pool.imap_unordered() to get results as they complete.
Save checkpoint after EACH batch completes using a multiprocessing Lock.

Key changes:
1. Use Manager().Lock() for thread-safe checkpoint writes
2. Replace pool.map() with pool.imap_unordered()
3. Update checkpoint after each batch result
4. Maintain backward compatibility with existing checkpoints

📋 IMPLEMENTATION STEPS
-----------------------
1. Add Manager and Lock initialization before Pool creation
2. Pass shared checkpoint data and lock to workers (via Manager)
3. Replace pool.map() with pool.imap_unordered()
4. In result loop: save checkpoint after each batch
5. Add error handling for checkpoint write failures

📋 RISKS & MITIGATIONS
----------------------
Risk: Checkpoint file corruption if two processes write simultaneously
→ Mitigation: Use multiprocessing.Lock() for exclusive access

Risk: Performance impact from frequent checkpoint writes
→ Mitigation: Checkpoint writes are fast (small JSON), negligible impact

Risk: Breaking existing runs that are already checkpointed
→ Mitigation: Maintain checkpoint format, only change timing

Risk: Bugs in multiprocessing lock/manager code
→ Mitigation: Thorough testing with this test script

📋 TESTING STRATEGY
-------------------
1. Run test_current_implementation() - Confirm bug exists
2. Apply fix to batch_runner.py
3. Run test_current_implementation() again - Should see incremental updates
4. Run test_interruption_and_resume() - Verify resume works
5. Manual test: Start run, kill process mid-batch, resume

📋 ROLLBACK PLAN
----------------
If issues arise:
1. Git revert the changes
2. Original code is working (just missing incremental checkpoint)
3. No data corruption risk - checkpoints are write-only
""")
|
||||
|
||||
|
||||
def main(
    test_current: bool = False,
    test_resume: bool = False,
    test_crash: bool = False,
    compare: bool = False,
    show_plan: bool = False
):
    """
    Run checkpoint behavior tests.

    Args:
        test_current: Test current implementation checkpoint timing
        test_resume: Test interruption and resume functionality
        test_crash: Test simulated crash scenario (manual)
        compare: Run all tests and compare
        show_plan: Show detailed fix plan
    """
    # No test selected (or plan explicitly requested): just show the plan.
    anything_requested = any([test_current, test_resume, test_crash, compare])
    if show_plan or not anything_requested:
        print_test_plan()
        return

    results = {}

    if compare or test_current:
        results['current'] = test_current_implementation()

    if compare or test_resume:
        results['resume'] = test_interruption_and_resume()

    if compare or test_crash:
        results['crash'] = test_simulated_crash()

    if not results:
        return

    # Summary
    banner = "=" * 70
    print("\n" + banner)
    print("OVERALL TEST SUMMARY")
    print(banner)
    for test_name, result in results.items():
        if result is None:
            status = "⏭️ SKIPPED"
        elif result:
            status = "✅ PASS"
        else:
            status = "❌ FAIL"
        print(f"{status} - {test_name}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # CLI entry point; fire exposes main()'s keyword arguments as --flags
    # (e.g. --test_current, --compare, --show_plan).
    import fire
    fire.Fire(main)
|
||||
|
||||
123
hermes_code/tests/integration/test_daytona_terminal.py
Normal file
123
hermes_code/tests/integration/test_daytona_terminal.py
Normal file
|
|
@ -0,0 +1,123 @@
|
|||
"""Integration tests for the Daytona terminal backend.
|
||||
|
||||
Requires DAYTONA_API_KEY to be set. Run with:
|
||||
TERMINAL_ENV=daytona pytest tests/integration/test_daytona_terminal.py -v
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
# Skip entire module if no API key
if not os.getenv("DAYTONA_API_KEY"):
    pytest.skip("DAYTONA_API_KEY not set", allow_module_level=True)

# Import terminal_tool via importlib to avoid tools/__init__.py side effects
import importlib.util

# Repo root (three levels up from tests/integration/<this file>); put it on
# sys.path so terminal_tool's own imports resolve.
parent_dir = Path(__file__).parent.parent.parent
sys.path.insert(0, str(parent_dir))

# Load tools/terminal_tool.py as a standalone module named "terminal_tool"
# without importing the tools package itself.
spec = importlib.util.spec_from_file_location(
    "terminal_tool", parent_dir / "tools" / "terminal_tool.py"
)
terminal_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(terminal_module)

# Re-export the two callables used by the tests below.
terminal_tool = terminal_module.terminal_tool
cleanup_vm = terminal_module.cleanup_vm
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def _force_daytona(monkeypatch):
    """Force every test in this module onto the Daytona backend (auto-reverted)."""
    env = {
        "TERMINAL_ENV": "daytona",
        "TERMINAL_CONTAINER_DISK": "10240",
        "TERMINAL_CONTAINER_PERSISTENT": "false",
    }
    for key, value in env.items():
        monkeypatch.setenv(key, value)
|
||||
|
||||
|
||||
@pytest.fixture()
def task_id(request):
    """Provide a unique task_id and clean up the sandbox after the test."""
    # Derive the id from the test name so concurrent tests never collide.
    identifier = f"daytona_test_{request.node.name}"
    yield identifier
    cleanup_vm(identifier)
|
||||
|
||||
|
||||
def _run(command, task_id, **kwargs):
    """Execute *command* through terminal_tool and decode its JSON reply."""
    raw = terminal_tool(command, task_id=task_id, **kwargs)
    return json.loads(raw)
|
||||
|
||||
|
||||
class TestDaytonaBasic:
    """Smoke tests: command execution, exit codes, and OS identity."""

    def test_echo(self, task_id):
        result = _run("echo 'Hello from Daytona!'", task_id)
        assert result["exit_code"] == 0
        assert "Hello from Daytona!" in result["output"]

    def test_python_version(self, task_id):
        result = _run("python3 --version", task_id)
        assert result["exit_code"] == 0
        assert "Python" in result["output"]

    def test_nonzero_exit(self, task_id):
        result = _run("exit 42", task_id)
        assert result["exit_code"] == 42

    def test_os_info(self, task_id):
        result = _run("uname -a", task_id)
        assert result["exit_code"] == 0
        assert "Linux" in result["output"]
|
||||
|
||||
|
||||
class TestDaytonaFilesystem:
    """File and package state within a single sandbox session."""

    def test_write_and_read_file(self, task_id):
        _run("echo 'test content' > /tmp/daytona_test.txt", task_id)
        result = _run("cat /tmp/daytona_test.txt", task_id)
        assert result["exit_code"] == 0
        assert "test content" in result["output"]

    def test_persistence_within_session(self, task_id):
        # Installation may be slow, hence the larger timeout for pip only.
        _run("pip install cowsay 2>/dev/null", task_id, timeout=120)
        result = _run('python3 -c "import cowsay; print(cowsay.__file__)"', task_id)
        assert result["exit_code"] == 0
        assert "cowsay" in result["output"]
|
||||
|
||||
|
||||
class TestDaytonaPersistence:
    """Sandbox filesystem persistence across stop/resume cycles."""

    def test_filesystem_survives_stop_and_resume(self):
        """Write a file, stop the sandbox, resume it, assert the file persists."""
        task = "daytona_test_persist"
        try:
            # Enable persistence for this test
            os.environ["TERMINAL_CONTAINER_PERSISTENT"] = "true"

            # Write a marker file and stop the sandbox
            _run("echo 'survive' > /tmp/persist_test.txt", task)
            cleanup_vm(task)  # stops (not deletes) because persistent=true

            # Resume with the same task_id — file should still exist
            resumed = _run("cat /tmp/persist_test.txt", task)
            assert resumed["exit_code"] == 0
            assert "survive" in resumed["output"]
        finally:
            # Force-delete so the sandbox doesn't leak
            os.environ["TERMINAL_CONTAINER_PERSISTENT"] = "false"
            cleanup_vm(task)
|
||||
|
||||
|
||||
class TestDaytonaIsolation:
    """Sandboxes created for different task ids must not share a filesystem."""

    def test_different_tasks_isolated(self):
        first, second = "daytona_test_iso_a", "daytona_test_iso_b"
        try:
            _run("echo 'secret' > /tmp/isolated.txt", first)
            probe = _run("cat /tmp/isolated.txt 2>&1 || echo NOT_FOUND", second)
            # Either the file is invisible in the second sandbox, or cat failed.
            assert "secret" not in probe["output"] or "NOT_FOUND" in probe["output"]
        finally:
            cleanup_vm(first)
            cleanup_vm(second)
|
||||
341
hermes_code/tests/integration/test_ha_integration.py
Normal file
341
hermes_code/tests/integration/test_ha_integration.py
Normal file
|
|
@ -0,0 +1,341 @@
|
|||
"""Integration tests for Home Assistant (tool + gateway).
|
||||
|
||||
Spins up a real in-process fake HA server (HTTP + WebSocket) and exercises
|
||||
the full adapter and tool handler paths over real TCP connections.
|
||||
No mocks -- only real async I/O against a fake server.
|
||||
|
||||
Run with: uv run pytest tests/integration/test_ha_integration.py -v
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
from gateway.platforms.homeassistant import HomeAssistantAdapter
|
||||
from tests.fakes.fake_ha_server import FakeHAServer, ENTITY_STATES
|
||||
from tools.homeassistant_tool import (
|
||||
_async_call_service,
|
||||
_async_get_state,
|
||||
_async_list_entities,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _adapter_for(server: FakeHAServer, **extra) -> HomeAssistantAdapter:
    """Create an adapter pointed at the fake server."""
    # Caller-supplied extras may override the default url.
    extras = {"url": server.url}
    extras.update(extra)
    config = PlatformConfig(enabled=True, token=server.token, extra=extras)
    return HomeAssistantAdapter(config)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 1. Gateway -- WebSocket lifecycle
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGatewayWebSocket:
    """WebSocket lifecycle tests for HomeAssistantAdapter against FakeHAServer:
    handshake, auth rejection, event forwarding/filtering, clean disconnect."""

    @pytest.mark.asyncio
    async def test_connect_auth_subscribe(self):
        """Full WS handshake succeeds: auth_required -> auth -> auth_ok -> subscribe -> ACK."""
        async with FakeHAServer() as server:
            adapter = _adapter_for(server)
            connected = await adapter.connect()
            assert connected is True
            # White-box checks on adapter internals: listener running, socket open.
            assert adapter._running is True
            assert adapter._ws is not None
            assert not adapter._ws.closed
            await adapter.disconnect()

    @pytest.mark.asyncio
    async def test_connect_auth_rejected(self):
        """connect() returns False when the server rejects auth."""
        async with FakeHAServer() as server:
            server.reject_auth = True
            adapter = _adapter_for(server)
            connected = await adapter.connect()
            assert connected is False

    @pytest.mark.asyncio
    async def test_event_received_and_forwarded(self):
        """Server pushes event -> adapter calls handle_message with correct MessageEvent."""
        async with FakeHAServer() as server:
            adapter = _adapter_for(server)
            adapter.handle_message = AsyncMock()

            await adapter.connect()

            # Push a state_changed event
            await server.push_event({
                "data": {
                    "entity_id": "light.bedroom",
                    "old_state": {"state": "off", "attributes": {}},
                    "new_state": {
                        "state": "on",
                        "attributes": {"friendly_name": "Bedroom Light"},
                    },
                }
            })

            # Wait for the adapter to process it (poll up to ~2.5s total)
            for _ in range(50):
                if adapter.handle_message.call_count > 0:
                    break
                await asyncio.sleep(0.05)

            assert adapter.handle_message.call_count == 1
            # First positional arg of the forwarded call is the MessageEvent.
            msg_event = adapter.handle_message.call_args[0][0]
            assert "Bedroom Light" in msg_event.text
            assert "turned on" in msg_event.text
            assert msg_event.source.platform == Platform.HOMEASSISTANT

            await adapter.disconnect()

    @pytest.mark.asyncio
    async def test_event_filtering_ignores_unwatched(self):
        """Events outside watch_domains are silently dropped."""
        async with FakeHAServer() as server:
            adapter = _adapter_for(server, watch_domains=["climate"])
            adapter.handle_message = AsyncMock()

            await adapter.connect()

            # Push a light event (not in watch_domains)
            await server.push_event({
                "data": {
                    "entity_id": "light.bedroom",
                    "old_state": {"state": "off", "attributes": {}},
                    "new_state": {
                        "state": "on",
                        "attributes": {"friendly_name": "Bedroom Light"},
                    },
                }
            })

            # Absence of a call can only be asserted after a grace period.
            await asyncio.sleep(0.5)
            assert adapter.handle_message.call_count == 0

            await adapter.disconnect()

    @pytest.mark.asyncio
    async def test_disconnect_closes_cleanly(self):
        """disconnect() cancels listener and closes WebSocket."""
        async with FakeHAServer() as server:
            adapter = _adapter_for(server)
            await adapter.connect()
            ws_ref = adapter._ws

            await adapter.disconnect()

            assert adapter._running is False
            assert adapter._listen_task is None
            assert adapter._ws is None
            # The original WS reference should be closed
            assert ws_ref.closed
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 2. REST tool handlers (real HTTP against fake server)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestToolRest:
    """Call the async tool functions directly against the fake server.

    Note: we call ``_async_*`` instead of the sync ``_handle_*`` wrappers
    because the sync wrappers use ``_run_async`` which blocks the event
    loop, deadlocking with the in-process fake server. The async functions
    are the real logic; the sync wrappers are trivial bridge code already
    covered by unit tests.
    """

    @pytest.mark.asyncio
    async def test_list_entities_returns_all(self, monkeypatch):
        """_async_list_entities returns all entities from the fake server."""
        async with FakeHAServer() as server:
            # Redirect the tool module's connection globals at the fake server.
            monkeypatch.setattr(
                "tools.homeassistant_tool._HASS_URL", server.url,
            )
            monkeypatch.setattr(
                "tools.homeassistant_tool._HASS_TOKEN", server.token,
            )

            result = await _async_list_entities()

            assert result["count"] == len(ENTITY_STATES)
            ids = {e["entity_id"] for e in result["entities"]}
            assert "light.bedroom" in ids
            assert "climate.thermostat" in ids

    @pytest.mark.asyncio
    async def test_list_entities_domain_filter(self, monkeypatch):
        """Domain filter is applied after fetching from server."""
        async with FakeHAServer() as server:
            monkeypatch.setattr(
                "tools.homeassistant_tool._HASS_URL", server.url,
            )
            monkeypatch.setattr(
                "tools.homeassistant_tool._HASS_TOKEN", server.token,
            )

            result = await _async_list_entities(domain="light")

            # The fake server's fixture exposes exactly two light.* entities —
            # TODO confirm against tests/fakes/fake_ha_server.py if ENTITY_STATES changes.
            assert result["count"] == 2
            for e in result["entities"]:
                assert e["entity_id"].startswith("light.")

    @pytest.mark.asyncio
    async def test_get_state_single_entity(self, monkeypatch):
        """_async_get_state returns full entity details."""
        async with FakeHAServer() as server:
            monkeypatch.setattr(
                "tools.homeassistant_tool._HASS_URL", server.url,
            )
            monkeypatch.setattr(
                "tools.homeassistant_tool._HASS_TOKEN", server.token,
            )

            result = await _async_get_state("light.bedroom")

            assert result["entity_id"] == "light.bedroom"
            assert result["state"] == "on"
            assert result["attributes"]["brightness"] == 200
            assert result["last_changed"] is not None

    @pytest.mark.asyncio
    async def test_get_state_not_found(self, monkeypatch):
        """Non-existent entity raises an aiohttp error (404)."""
        # Aliased local import keeps aiohttp usage scoped to this one test.
        import aiohttp as _aiohttp

        async with FakeHAServer() as server:
            monkeypatch.setattr(
                "tools.homeassistant_tool._HASS_URL", server.url,
            )
            monkeypatch.setattr(
                "tools.homeassistant_tool._HASS_TOKEN", server.token,
            )

            with pytest.raises(_aiohttp.ClientResponseError) as exc_info:
                await _async_get_state("light.nonexistent")
            assert exc_info.value.status == 404

    @pytest.mark.asyncio
    async def test_call_service_turn_on(self, monkeypatch):
        """_async_call_service sends correct payload and server records it."""
        async with FakeHAServer() as server:
            monkeypatch.setattr(
                "tools.homeassistant_tool._HASS_URL", server.url,
            )
            monkeypatch.setattr(
                "tools.homeassistant_tool._HASS_TOKEN", server.token,
            )

            result = await _async_call_service(
                domain="light",
                service="turn_on",
                entity_id="light.bedroom",
                data={"brightness": 255},
            )

            assert result["success"] is True
            assert result["service"] == "light.turn_on"
            assert len(result["affected_entities"]) == 1
            assert result["affected_entities"][0]["state"] == "on"

            # Verify fake server recorded the call
            assert len(server.received_service_calls) == 1
            call = server.received_service_calls[0]
            assert call["domain"] == "light"
            assert call["service"] == "turn_on"
            assert call["data"]["entity_id"] == "light.bedroom"
            assert call["data"]["brightness"] == 255
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 3. send() -- REST notification
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSendNotification:
    """REST notification delivery through the adapter's send()."""

    @pytest.mark.asyncio
    async def test_send_notification_delivered(self):
        """A send() call lands on the fake server's REST notification endpoint."""
        async with FakeHAServer() as server:
            adapter = _adapter_for(server)

            outcome = await adapter.send("ha_events", "Test notification from agent")

            assert outcome.success is True
            assert len(server.received_notifications) == 1
            delivered = server.received_notifications[0]
            assert delivered["title"] == "Hermes Agent"
            assert delivered["message"] == "Test notification from agent"

    @pytest.mark.asyncio
    async def test_send_auth_failure(self):
        """A wrong token makes send() report failure with a 401 in the error."""
        async with FakeHAServer() as server:
            bad_config = PlatformConfig(
                enabled=True,
                token="wrong-token",
                extra={"url": server.url},
            )
            adapter = HomeAssistantAdapter(bad_config)

            outcome = await adapter.send("ha_events", "Should fail")

            assert outcome.success is False
            assert "401" in outcome.error
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 4. Auth and error cases
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestAuthAndErrors:
    """HTTP error propagation from the async REST helpers."""

    @pytest.mark.asyncio
    async def test_rest_unauthorized(self, monkeypatch):
        """A wrong bearer token surfaces as a 401 ClientResponseError."""
        import aiohttp as _aiohttp

        async with FakeHAServer() as server:
            monkeypatch.setattr("tools.homeassistant_tool._HASS_URL", server.url)
            monkeypatch.setattr("tools.homeassistant_tool._HASS_TOKEN", "bad-token")

            with pytest.raises(_aiohttp.ClientResponseError) as exc_info:
                await _async_list_entities()
            assert exc_info.value.status == 401

    @pytest.mark.asyncio
    async def test_rest_server_error(self, monkeypatch):
        """A forced 500 from the server surfaces as a ClientResponseError."""
        import aiohttp as _aiohttp

        async with FakeHAServer() as server:
            server.force_500 = True
            monkeypatch.setattr("tools.homeassistant_tool._HASS_URL", server.url)
            monkeypatch.setattr("tools.homeassistant_tool._HASS_TOKEN", server.token)

            with pytest.raises(_aiohttp.ClientResponseError) as exc_info:
                await _async_list_entities()
            assert exc_info.value.status == 500
|
||||
301
hermes_code/tests/integration/test_modal_terminal.py
Normal file
301
hermes_code/tests/integration/test_modal_terminal.py
Normal file
|
|
@ -0,0 +1,301 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test Modal Terminal Tool
|
||||
|
||||
This script tests that the Modal terminal backend is correctly configured
|
||||
and can execute commands in Modal sandboxes.
|
||||
|
||||
Usage:
|
||||
# Run with Modal backend
|
||||
TERMINAL_ENV=modal python tests/test_modal_terminal.py
|
||||
|
||||
# Or run directly (will use whatever TERMINAL_ENV is set in .env)
|
||||
python tests/test_modal_terminal.py
|
||||
"""
|
||||
|
||||
import pytest
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
# Try to load .env file if python-dotenv is available
|
||||
# Load environment variables from .env.  Prefer python-dotenv; fall back to a
# minimal hand-rolled parser so this script still works without the package.
try:
    from dotenv import load_dotenv
    load_dotenv()
except ImportError:
    # Manually load .env if dotenv not available
    env_file = Path(__file__).parent.parent.parent / ".env"
    if env_file.exists():
        with open(env_file) as f:
            for line in f:
                line = line.strip()
                # Only KEY=VALUE lines are honoured; blanks and '#' comments are skipped.
                if line and not line.startswith('#') and '=' in line:
                    key, value = line.split('=', 1)
                    # Remove quotes if present
                    value = value.strip().strip('"').strip("'")
                    # setdefault: variables already in the real environment win over .env.
                    os.environ.setdefault(key.strip(), value)

# Add project root to path for imports
parent_dir = Path(__file__).parent.parent.parent
sys.path.insert(0, str(parent_dir))

# Import terminal_tool module directly using importlib to avoid tools/__init__.py
# (importing the package would drag in unrelated tool dependencies).
import importlib.util
terminal_tool_path = parent_dir / "tools" / "terminal_tool.py"
spec = importlib.util.spec_from_file_location("terminal_tool", terminal_tool_path)
terminal_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(terminal_module)

# Re-export the pieces of terminal_tool exercised by the tests below.
terminal_tool = terminal_module.terminal_tool
check_terminal_requirements = terminal_module.check_terminal_requirements
_get_env_config = terminal_module._get_env_config
cleanup_vm = terminal_module.cleanup_vm
get_active_environments_info = terminal_module.get_active_environments_info
|
||||
|
||||
|
||||
def test_modal_requirements():
    """Report the configured terminal backend and verify Modal prerequisites.

    Returns True only when TERMINAL_ENV is 'modal' and the backend's own
    requirement check passes.
    """
    print("\n" + "=" * 60)
    print("TEST 1: Modal Requirements Check")
    print("=" * 60)

    config = _get_env_config()
    print(f"Current TERMINAL_ENV: {config['env_type']}")
    print(f"Modal image: {config['modal_image']}")

    # Modal credentials can come from an env var or from ~/.modal.toml.
    modal_token = os.getenv("MODAL_TOKEN_ID")
    modal_toml = Path.home() / ".modal.toml"

    print(f"\nModal authentication:")
    print(f" MODAL_TOKEN_ID env var: {'✅ Set' if modal_token else '❌ Not set'}")
    print(f" ~/.modal.toml file: {'✅ Exists' if modal_toml.exists() else '❌ Not found'}")

    if config['env_type'] != 'modal':
        print(f"\n⚠️ TERMINAL_ENV is '{config['env_type']}', not 'modal'")
        print(" Set TERMINAL_ENV=modal in .env or export it to test Modal backend")
        return False

    requirements_met = check_terminal_requirements()
    print(f"\nRequirements check: {'✅ Passed' if requirements_met else '❌ Failed'}")

    return requirements_met
|
||||
|
||||
|
||||
def test_simple_command():
    """Test executing a simple command.

    Runs an echo inside a fresh Modal environment and verifies both the
    exit code and the captured output.  cleanup_vm() is wrapped in
    try/finally so the sandbox is torn down even if terminal_tool or the
    JSON parse raises (the original leaked the environment on failure).

    Returns:
        bool: True when the command succeeded and produced the expected output.
    """
    print("\n" + "=" * 60)
    print("TEST 2: Simple Command Execution")
    print("=" * 60)

    test_task_id = "modal_test_simple"

    try:
        print("Executing: echo 'Hello from Modal!'")
        result = terminal_tool("echo 'Hello from Modal!'", task_id=test_task_id)
        result_json = json.loads(result)

        print(f"\nResult:")
        print(f" Output: {result_json.get('output', '')[:200]}")
        print(f" Exit code: {result_json.get('exit_code')}")
        print(f" Error: {result_json.get('error')}")

        success = result_json.get('exit_code') == 0 and 'Hello from Modal!' in result_json.get('output', '')
        print(f"\nTest: {'✅ Passed' if success else '❌ Failed'}")
    finally:
        # Always release the sandbox, even when execution or parsing failed,
        # so aborted runs don't leak Modal environments.
        cleanup_vm(test_task_id)

    return success
|
||||
|
||||
|
||||
def test_python_execution():
    """Test executing Python code in Modal.

    Runs `python3 -c` in the sandbox and checks the interpreter banner is
    printed.  cleanup_vm() is wrapped in try/finally so a failure inside
    terminal_tool or json.loads cannot leak the environment (the original
    skipped cleanup on exceptions).

    Returns:
        bool: True when the command exited 0 and printed 'Python'.
    """
    print("\n" + "=" * 60)
    print("TEST 3: Python Execution")
    print("=" * 60)

    test_task_id = "modal_test_python"

    python_cmd = 'python3 -c "import sys; print(f\'Python {sys.version}\')"'
    print(f"Executing: {python_cmd}")

    try:
        result = terminal_tool(python_cmd, task_id=test_task_id)
        result_json = json.loads(result)

        print(f"\nResult:")
        print(f" Output: {result_json.get('output', '')[:200]}")
        print(f" Exit code: {result_json.get('exit_code')}")
        print(f" Error: {result_json.get('error')}")

        success = result_json.get('exit_code') == 0 and 'Python' in result_json.get('output', '')
        print(f"\nTest: {'✅ Passed' if success else '❌ Failed'}")
    finally:
        # Guarantee sandbox teardown on every exit path.
        cleanup_vm(test_task_id)

    return success
|
||||
|
||||
|
||||
def test_pip_install():
    """Test installing a package with pip in Modal.

    Installs cowsay inside the sandbox and verifies it is importable and
    runnable.  cleanup_vm() is wrapped in try/finally so a failure in the
    (long-running) install cannot leak the environment.

    Returns:
        bool: True when the install and the import both succeeded.
    """
    print("\n" + "=" * 60)
    print("TEST 4: Pip Install Test")
    print("=" * 60)

    test_task_id = "modal_test_pip"

    # Install a small package and verify
    print("Executing: pip install --break-system-packages cowsay && python3 -c \"import cowsay; cowsay.cow('Modal works!')\"")

    try:
        result = terminal_tool(
            "pip install --break-system-packages cowsay && python3 -c \"import cowsay; cowsay.cow('Modal works!')\"",
            task_id=test_task_id,
            timeout=120  # pip can be slow on a cold image
        )
        result_json = json.loads(result)

        print(f"\nResult:")
        output = result_json.get('output', '')
        print(f" Output (last 500 chars): ...{output[-500:] if len(output) > 500 else output}")
        print(f" Exit code: {result_json.get('exit_code')}")
        print(f" Error: {result_json.get('error')}")

        success = result_json.get('exit_code') == 0 and 'Modal works!' in result_json.get('output', '')
        print(f"\nTest: {'✅ Passed' if success else '❌ Failed'}")
    finally:
        # Tear down even when install/exec raises, to avoid leaked sandboxes.
        cleanup_vm(test_task_id)

    return success
|
||||
|
||||
|
||||
def test_filesystem_persistence():
    """Test that filesystem persists between commands in the same task.

    Writes a file with one command and reads it back with a second command
    on the same task_id.  cleanup_vm() is wrapped in try/finally so the
    environment is removed even if either command raises (the original
    leaked it on failure).

    Returns:
        bool: True when both commands exited 0 and the content round-tripped.
    """
    print("\n" + "=" * 60)
    print("TEST 5: Filesystem Persistence")
    print("=" * 60)

    test_task_id = "modal_test_persist"

    try:
        # Create a file
        print("Step 1: Creating test file...")
        result1 = terminal_tool("echo 'persistence test' > /tmp/modal_test.txt", task_id=test_task_id)
        result1_json = json.loads(result1)
        print(f" Exit code: {result1_json.get('exit_code')}")

        # Read the file back
        print("Step 2: Reading test file...")
        result2 = terminal_tool("cat /tmp/modal_test.txt", task_id=test_task_id)
        result2_json = json.loads(result2)
        print(f" Output: {result2_json.get('output', '')}")
        print(f" Exit code: {result2_json.get('exit_code')}")

        success = (
            result1_json.get('exit_code') == 0 and
            result2_json.get('exit_code') == 0 and
            'persistence test' in result2_json.get('output', '')
        )
        print(f"\nTest: {'✅ Passed' if success else '❌ Failed'}")
    finally:
        # Environment must be removed even if either command raised.
        cleanup_vm(test_task_id)

    return success
|
||||
|
||||
|
||||
def test_environment_isolation():
    """Test that different task_ids get isolated environments.

    Writes a file under one task_id and verifies it is invisible from a
    second task_id.  Both environments are cleaned up via try/finally so
    neither can leak on an exception (the original skipped cleanup when a
    command raised).

    Returns:
        bool: True when the second environment cannot see the first's file.
    """
    print("\n" + "=" * 60)
    print("TEST 6: Environment Isolation")
    print("=" * 60)

    task1 = "modal_test_iso_1"
    task2 = "modal_test_iso_2"

    try:
        # Create file in task1
        print("Step 1: Creating file in task1...")
        result1 = terminal_tool("echo 'task1 data' > /tmp/isolated.txt", task_id=task1)

        # Try to read from task2 (should not exist)
        print("Step 2: Trying to read file from task2 (should not exist)...")
        result2 = terminal_tool("cat /tmp/isolated.txt 2>&1 || echo 'FILE_NOT_FOUND'", task_id=task2)
        result2_json = json.loads(result2)

        # The file should either not exist or be empty in task2
        output = result2_json.get('output', '')
        isolated = 'task1 data' not in output or 'FILE_NOT_FOUND' in output or 'No such file' in output

        print(f" Task2 output: {output[:200]}")
        print(f"\nTest: {'✅ Passed (environments isolated)' if isolated else '❌ Failed (environments NOT isolated)'}")
    finally:
        # Always remove both sandboxes, even on a mid-test failure.
        cleanup_vm(task1)
        cleanup_vm(task2)

    return isolated
|
||||
|
||||
|
||||
def main():
    """Run all Modal terminal tests.

    Returns:
        bool: True only when every test passed.  Early aborts now return an
        explicit False instead of a bare ``return`` (which produced None),
        so the function's result is consistently boolean for the
        ``sys.exit(0 if success else 1)`` mapping at the bottom of the file.
    """
    print("🧪 Modal Terminal Tool Test Suite")
    print("=" * 60)

    # Check current config
    config = _get_env_config()
    print(f"\nCurrent configuration:")
    print(f" TERMINAL_ENV: {config['env_type']}")
    print(f" TERMINAL_MODAL_IMAGE: {config['modal_image']}")
    print(f" TERMINAL_TIMEOUT: {config['timeout']}s")

    if config['env_type'] != 'modal':
        print(f"\n⚠️ WARNING: TERMINAL_ENV is set to '{config['env_type']}', not 'modal'")
        print(" To test Modal specifically, set TERMINAL_ENV=modal")
        response = input("\n Continue testing with current backend? (y/n): ")
        if response.lower() != 'y':
            print("Aborting.")
            return False  # explicit bool (was bare return → None)

    results = {}

    # Run tests
    results['requirements'] = test_modal_requirements()

    if not results['requirements']:
        print("\n❌ Requirements not met. Cannot continue with other tests.")
        return False  # explicit bool (was bare return → None)

    results['simple_command'] = test_simple_command()
    results['python_execution'] = test_python_execution()
    results['pip_install'] = test_pip_install()
    results['filesystem_persistence'] = test_filesystem_persistence()
    results['environment_isolation'] = test_environment_isolation()

    # Summary
    print("\n" + "=" * 60)
    print("TEST SUMMARY")
    print("=" * 60)

    passed = sum(1 for v in results.values() if v)
    total = len(results)

    for test_name, passed_test in results.items():
        status = "✅ PASSED" if passed_test else "❌ FAILED"
        print(f" {test_name}: {status}")

    print(f"\nTotal: {passed}/{total} tests passed")

    # Show active environments
    env_info = get_active_environments_info()
    print(f"\nActive environments after tests: {env_info['count']}")
    if env_info['count'] > 0:
        print(f" Task IDs: {env_info['task_ids']}")

    return passed == total
|
||||
|
||||
|
||||
# Script entry point: exit status 0 only when every test passed.
if __name__ == "__main__":
    success = main()
    sys.exit(0 if success else 1)
|
||||
611
hermes_code/tests/integration/test_voice_channel_flow.py
Normal file
611
hermes_code/tests/integration/test_voice_channel_flow.py
Normal file
|
|
@ -0,0 +1,611 @@
|
|||
"""Integration tests for Discord voice channel audio flow.
|
||||
|
||||
Uses real NaCl encryption and Opus codec (no mocks for crypto/codec).
|
||||
Does NOT require a Discord connection — tests the VoiceReceiver
|
||||
packet processing pipeline end-to-end.
|
||||
|
||||
Requires: PyNaCl>=1.5.0, discord.py[voice] (opus codec)
|
||||
"""
|
||||
|
||||
import struct
|
||||
import time
|
||||
import pytest
|
||||
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
# Skip entire module if voice deps are missing
|
||||
pytest.importorskip("nacl.secret", reason="PyNaCl required for voice integration tests")
|
||||
discord = pytest.importorskip("discord", reason="discord.py required for voice integration tests")
|
||||
|
||||
import nacl.secret
|
||||
|
||||
# Best-effort load of the native Opus codec for discord.py.  When the
# library cannot be found via ctypes, probe common Homebrew install paths
# (ARM and Intel macOS) before giving up.  OPUS_AVAILABLE records whether
# the codec ended up loaded; any loader failure just disables codec tests.
try:
    if not discord.opus.is_loaded():
        import ctypes.util
        opus_path = ctypes.util.find_library("opus")
        if not opus_path:
            import sys
            # Fallback: well-known Homebrew library locations.
            for p in ("/opt/homebrew/lib/libopus.dylib", "/usr/local/lib/libopus.dylib"):
                import os
                if os.path.isfile(p):
                    opus_path = p
                    break
        if opus_path:
            discord.opus.load_opus(opus_path)
    OPUS_AVAILABLE = discord.opus.is_loaded()
except Exception:
    # Loading failed entirely — voice codec functionality is unavailable.
    OPUS_AVAILABLE = False
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
from gateway.platforms.discord import VoiceReceiver
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_secret_key():
|
||||
"""Generate a random 32-byte key."""
|
||||
import os
|
||||
return os.urandom(32)
|
||||
|
||||
|
||||
def _build_encrypted_rtp_packet(secret_key, opus_payload, ssrc=100, seq=1, timestamp=960):
    """Build a real NaCl-encrypted RTP packet in Discord's wire format.

    Layout: 12-byte RTP header + AEAD ciphertext + 4-byte nonce counter.
    The RTP header doubles as the AAD for aead_xchacha20_poly1305.
    """
    # RTP header: version 2, payload type 0x78, no extension, no CSRC.
    rtp_header = struct.pack(">BBHII", 0x80, 0x78, seq, timestamp, ssrc)

    # Discord derives the 24-byte XChaCha20 nonce from a 4-byte counter
    # (seeded here with the sequence number) padded out with zero bytes.
    counter = struct.pack(">I", seq)
    aead = nacl.secret.Aead(secret_key)
    sealed = aead.encrypt(opus_payload, rtp_header, counter + b'\x00' * 20)

    # Trailing counter lets the receiver reconstruct the full nonce.
    return rtp_header + sealed.ciphertext + counter
|
||||
|
||||
|
||||
def _make_voice_receiver(secret_key, dave_session=None, bot_ssrc=9999,
                         allowed_user_ids=None, members=None):
    """Build and start a VoiceReceiver wired to a mocked voice client."""
    voice_client = MagicMock()
    connection = voice_client._connection
    connection.secret_key = list(secret_key)
    connection.dave_session = dave_session
    connection.ssrc = bot_ssrc
    connection.add_socket_listener = MagicMock()
    connection.remove_socket_listener = MagicMock()
    connection.hook = None

    voice_client.user = SimpleNamespace(id=bot_ssrc)
    voice_client.channel = MagicMock()
    voice_client.channel.members = members or []

    rx = VoiceReceiver(voice_client, allowed_user_ids=allowed_user_ids)
    rx.start()
    return rx
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRealNaClDecrypt:
    """End-to-end: real NaCl encrypt → _on_packet decrypt → buffer."""

    def test_valid_encrypted_packet_buffered(self):
        """A correctly-encrypted packet is decrypted and lands in the buffer."""
        key = _make_secret_key()
        receiver = _make_voice_receiver(key)

        receiver._on_packet(_build_encrypted_rtp_packet(key, b'\xf8\xff\xfe', ssrc=100))

        assert 100 in receiver._buffers
        assert len(receiver._buffers[100]) > 0

    def test_wrong_key_packet_dropped(self):
        """A packet sealed with a different key fails NaCl → nothing buffered."""
        real_key = _make_secret_key()
        wrong_key = _make_secret_key()
        receiver = _make_voice_receiver(real_key)

        receiver._on_packet(_build_encrypted_rtp_packet(wrong_key, b'\xf8\xff\xfe', ssrc=100))

        assert len(receiver._buffers.get(100, b"")) == 0

    def test_bot_ssrc_ignored(self):
        """Audio carrying the bot's own SSRC is discarded (no self-echo)."""
        key = _make_secret_key()
        receiver = _make_voice_receiver(key, bot_ssrc=9999)

        receiver._on_packet(_build_encrypted_rtp_packet(key, b'\xf8\xff\xfe', ssrc=9999))

        assert len(receiver._buffers) == 0

    def test_multiple_packets_accumulate(self):
        """Successive valid packets grow the per-SSRC buffer."""
        key = _make_secret_key()
        receiver = _make_voice_receiver(key)

        for n in range(1, 6):
            receiver._on_packet(
                _build_encrypted_rtp_packet(key, b'\xf8\xff\xfe', ssrc=100, seq=n, timestamp=960 * n)
            )

        assert 100 in receiver._buffers
        assert len(receiver._buffers[100]) > 0, "Multiple packets should accumulate in buffer"

    def test_different_ssrcs_separate_buffers(self):
        """Each SSRC gets its own, independent buffer."""
        key = _make_secret_key()
        receiver = _make_voice_receiver(key)

        sources = [100, 200, 300]
        for ssrc in sources:
            receiver._on_packet(_build_encrypted_rtp_packet(key, b'\xf8\xff\xfe', ssrc=ssrc))

        assert len(receiver._buffers) == 3
        assert all(ssrc in receiver._buffers for ssrc in sources)
|
||||
|
||||
|
||||
class TestRealNaClWithDAVE:
    """NaCl decrypt + DAVE passthrough scenarios with real crypto."""

    def test_dave_unknown_ssrc_passthrough(self):
        """DAVE session present but SSRC unmapped → DAVE skipped, audio buffered."""
        key = _make_secret_key()
        dave_session = MagicMock()  # present, but no SSRC → user mapping exists
        receiver = _make_voice_receiver(key, dave_session=dave_session)

        receiver._on_packet(_build_encrypted_rtp_packet(key, b'\xf8\xff\xfe', ssrc=100))

        dave_session.decrypt.assert_not_called()
        assert 100 in receiver._buffers
        assert len(receiver._buffers[100]) > 0

    def test_dave_unencrypted_error_passthrough(self):
        """An 'Unencrypted' DAVE failure falls back to the NaCl plaintext."""
        key = _make_secret_key()
        dave_session = MagicMock()
        dave_session.decrypt.side_effect = Exception(
            "DecryptionFailed(UnencryptedWhenPassthroughDisabled)"
        )
        receiver = _make_voice_receiver(key, dave_session=dave_session)
        receiver.map_ssrc(100, 42)

        receiver._on_packet(_build_encrypted_rtp_packet(key, b'\xf8\xff\xfe', ssrc=100))

        dave_session.decrypt.assert_called_once()
        assert 100 in receiver._buffers
        assert len(receiver._buffers[100]) > 0

    def test_dave_real_error_drops(self):
        """Any other DAVE failure drops the packet instead of passing through."""
        key = _make_secret_key()
        dave_session = MagicMock()
        dave_session.decrypt.side_effect = Exception("KeyRotationFailed")
        receiver = _make_voice_receiver(key, dave_session=dave_session)
        receiver.map_ssrc(100, 42)

        receiver._on_packet(_build_encrypted_rtp_packet(key, b'\xf8\xff\xfe', ssrc=100))

        assert len(receiver._buffers.get(100, b"")) == 0
|
||||
|
||||
|
||||
class TestFullVoiceFlow:
    """End-to-end: encrypt → receive → buffer → silence detect → complete."""

    def test_single_utterance_flow(self):
        """Buffered speech followed by silence yields one completed utterance."""
        key = _make_secret_key()
        receiver = _make_voice_receiver(key)
        receiver.map_ssrc(100, 42)

        # 29 frames comfortably exceeds MIN_SPEECH_DURATION (0.5s): at
        # 48 kHz stereo 16-bit each Opus silence frame decodes to ~3840
        # bytes, and ~96000 bytes (~25 frames) are needed.
        for n in range(1, 30):
            receiver._on_packet(
                _build_encrypted_rtp_packet(key, b'\xf8\xff\xfe', ssrc=100, seq=n, timestamp=960 * n)
            )

        # Age the stream so check_silence treats the utterance as finished.
        receiver._last_packet_time[100] = time.monotonic() - 3.0

        finished = receiver.check_silence()
        assert len(finished) == 1
        speaker_id, pcm = finished[0]
        assert speaker_id == 42
        assert len(pcm) > 0

    def test_utterance_with_ssrc_automap(self):
        """Without a SPEAKING event, the sole allowed user is auto-mapped."""
        key = _make_secret_key()
        channel_members = [
            SimpleNamespace(id=9999, name="Bot"),
            SimpleNamespace(id=42, name="Alice"),
        ]
        receiver = _make_voice_receiver(
            key, allowed_user_ids={"42"}, members=channel_members
        )
        # Deliberately no map_ssrc() — simulates a missed SPEAKING event.

        for n in range(1, 30):
            receiver._on_packet(
                _build_encrypted_rtp_packet(key, b'\xf8\xff\xfe', ssrc=100, seq=n, timestamp=960 * n)
            )

        receiver._last_packet_time[100] = time.monotonic() - 3.0

        finished = receiver.check_silence()
        assert len(finished) == 1
        assert finished[0][0] == 42  # auto-mapped to sole allowed user

    def test_pause_blocks_during_playback(self):
        """pause() drops incoming packets; resume() accepts them again."""
        key = _make_secret_key()
        receiver = _make_voice_receiver(key)
        packet = _build_encrypted_rtp_packet(key, b'\xf8\xff\xfe', ssrc=100)

        # Paused (echo prevention during TTS playback): nothing buffered.
        receiver.pause()
        receiver._on_packet(packet)
        assert len(receiver._buffers.get(100, b"")) == 0

        # Resumed: the same packet is now accepted.
        receiver.resume()
        receiver._on_packet(packet)
        assert 100 in receiver._buffers
        assert len(receiver._buffers[100]) > 0

    def test_corrupted_packet_ignored(self):
        """Truncated or malformed packets are silently dropped."""
        key = _make_secret_key()
        receiver = _make_voice_receiver(key)

        # Too short to be an RTP packet at all.
        receiver._on_packet(b"\x00" * 5)
        assert len(receiver._buffers) == 0

        # RTP version byte is wrong (0 instead of 2).
        receiver._on_packet(struct.pack(">BBHII", 0x00, 0x78, 1, 960, 100) + b"\x00" * 20)
        assert len(receiver._buffers) == 0

        # Payload type is not Discord's Opus type (0x78).
        receiver._on_packet(struct.pack(">BBHII", 0x80, 0x00, 1, 960, 100) + b"\x00" * 20)
        assert len(receiver._buffers) == 0

    def test_stop_cleans_everything(self):
        """stop() flips _running and wipes buffers, SSRC map and decoders."""
        key = _make_secret_key()
        receiver = _make_voice_receiver(key)
        receiver.map_ssrc(100, 42)

        for n in range(1, 10):
            receiver._on_packet(
                _build_encrypted_rtp_packet(key, b'\xf8\xff\xfe', ssrc=100, seq=n, timestamp=960 * n)
            )
        assert len(receiver._buffers[100]) > 0

        receiver.stop()
        assert receiver._running is False
        assert len(receiver._buffers) == 0
        assert len(receiver._ssrc_to_user) == 0
        assert len(receiver._decoders) == 0
|
||||
|
||||
|
||||
class TestSPEAKINGHook:
    """SPEAKING event hook correctly maps SSRC to user_id."""

    def test_speaking_hook_installed(self):
        """start() installs a (wrapped) speaking hook on the connection."""
        receiver = _make_voice_receiver(_make_secret_key())
        assert receiver._vc._connection.hook is not None

    def test_map_ssrc_via_speaking(self):
        """map_ssrc records the SSRC → user_id association."""
        receiver = _make_voice_receiver(_make_secret_key())
        receiver.map_ssrc(500, 12345)
        assert receiver._ssrc_to_user[500] == 12345

    def test_map_ssrc_overwrites(self):
        """Re-mapping the same SSRC replaces the previous user."""
        receiver = _make_voice_receiver(_make_secret_key())
        receiver.map_ssrc(500, 111)
        receiver.map_ssrc(500, 222)
        assert receiver._ssrc_to_user[500] == 222

    def test_speaking_mapped_audio_processed(self):
        """Once mapped, a finished utterance carries the mapped user_id."""
        key = _make_secret_key()
        receiver = _make_voice_receiver(key)
        receiver.map_ssrc(100, 42)

        for n in range(1, 30):
            receiver._on_packet(
                _build_encrypted_rtp_packet(key, b'\xf8\xff\xfe', ssrc=100, seq=n, timestamp=960 * n)
            )

        receiver._last_packet_time[100] = time.monotonic() - 3.0
        finished = receiver.check_silence()
        assert len(finished) == 1
        assert finished[0][0] == 42
|
||||
|
||||
|
||||
class TestAuthFiltering:
    """Only allowed users' audio should be processed."""

    def test_allowed_user_audio_processed(self):
        """An allow-listed user's utterance comes back from check_silence."""
        key = _make_secret_key()
        channel_members = [
            SimpleNamespace(id=9999, name="Bot"),
            SimpleNamespace(id=42, name="Alice"),
        ]
        receiver = _make_voice_receiver(
            key, allowed_user_ids={"42"}, members=channel_members,
        )
        receiver.map_ssrc(100, 42)

        for n in range(1, 30):
            receiver._on_packet(
                _build_encrypted_rtp_packet(key, b'\xf8\xff\xfe', ssrc=100, seq=n, timestamp=960 * n)
            )

        receiver._last_packet_time[100] = time.monotonic() - 3.0
        finished = receiver.check_silence()
        assert len(finished) == 1
        assert finished[0][0] == 42

    def test_automap_rejects_unallowed_user(self):
        """Auto-mapping refuses users missing from the allow list."""
        key = _make_secret_key()
        channel_members = [
            SimpleNamespace(id=9999, name="Bot"),
            SimpleNamespace(id=42, name="Alice"),
        ]
        receiver = _make_voice_receiver(
            key, allowed_user_ids={"99"},  # Alice not allowed
            members=channel_members,
        )
        # No map_ssrc — the SSRC stays unknown, so auto-map must reject it.

        for n in range(1, 30):
            receiver._on_packet(
                _build_encrypted_rtp_packet(key, b'\xf8\xff\xfe', ssrc=100, seq=n, timestamp=960 * n)
            )

        receiver._last_packet_time[100] = time.monotonic() - 3.0
        assert len(receiver.check_silence()) == 0

    def test_empty_allowlist_allows_all(self):
        """allowed_user_ids=None disables filtering entirely."""
        key = _make_secret_key()
        channel_members = [
            SimpleNamespace(id=9999, name="Bot"),
            SimpleNamespace(id=42, name="Alice"),
        ]
        receiver = _make_voice_receiver(
            key, allowed_user_ids=None, members=channel_members,
        )

        for n in range(1, 30):
            receiver._on_packet(
                _build_encrypted_rtp_packet(key, b'\xf8\xff\xfe', ssrc=100, seq=n, timestamp=960 * n)
            )

        receiver._last_packet_time[100] = time.monotonic() - 3.0
        finished = receiver.check_silence()
        # Auto-mapped to the sole non-bot channel member.
        assert len(finished) == 1
        assert finished[0][0] == 42
|
||||
|
||||
|
||||
class TestRejoinFlow:
    """Leave and rejoin: state cleanup and fresh receiver."""

    @staticmethod
    def _feed(receiver, key, ssrc, first_seq, last_seq):
        """Push one encrypted voice packet per sequence number in [first_seq, last_seq)."""
        for seq in range(first_seq, last_seq):
            receiver._on_packet(
                _build_encrypted_rtp_packet(
                    key, b'\xf8\xff\xfe', ssrc=ssrc, seq=seq, timestamp=960 * seq
                )
            )

    def test_stop_then_new_receiver_clean_state(self):
        """After stop(), a freshly built receiver carries no stale state."""
        key = _make_secret_key()
        old = _make_voice_receiver(key)
        old.map_ssrc(100, 42)

        self._feed(old, key, ssrc=100, first_seq=1, last_seq=10)
        assert len(old._buffers[100]) > 0
        old.stop()

        # Rejoin is modelled as constructing a brand-new receiver.
        fresh = _make_voice_receiver(key)
        assert len(fresh._buffers) == 0
        assert len(fresh._ssrc_to_user) == 0
        assert len(fresh._decoders) == 0

    def test_rejoin_new_ssrc_works(self):
        """A user may come back under a different SSRC — mapping still works."""
        key = _make_secret_key()
        old = _make_voice_receiver(key)
        old.map_ssrc(100, 42)  # SSRC from the first session
        old.stop()

        fresh = _make_voice_receiver(key)
        fresh.map_ssrc(200, 42)  # new SSRC after rejoin
        self._feed(fresh, key, ssrc=200, first_seq=1, last_seq=30)

        fresh._last_packet_time[200] = time.monotonic() - 3.0
        completed = fresh.check_silence()
        assert len(completed) == 1
        assert completed[0][0] == 42

    def test_rejoin_without_speaking_event_automap(self):
        """Rejoin with no SPEAKING event: sole allowed user is auto-mapped."""
        guild_members = [
            SimpleNamespace(id=9999, name="Bot"),
            SimpleNamespace(id=42, name="Alice"),
        ]

        # First session comes and goes.
        first_key = _make_secret_key()
        _make_voice_receiver(
            first_key, allowed_user_ids={"42"}, members=guild_members,
        ).stop()

        # Rejoin — Discord may hand out a brand-new secret_key.
        second_key = _make_secret_key()
        rejoined = _make_voice_receiver(
            second_key, allowed_user_ids={"42"}, members=guild_members,
        )
        # Intentionally no map_ssrc(): simulates the missing SPEAKING event.
        self._feed(rejoined, second_key, ssrc=300, first_seq=1, last_seq=30)

        rejoined._last_packet_time[300] = time.monotonic() - 3.0
        completed = rejoined.check_silence()
        assert len(completed) == 1
        assert completed[0][0] == 42
|
||||
|
||||
|
||||
class TestMultiGuildIsolation:
    """Each guild has independent voice state."""

    def test_separate_receivers_independent(self):
        """Traffic delivered to one guild's receiver never leaks into another."""
        key_a = _make_secret_key()
        key_b = _make_secret_key()

        guild_a = _make_voice_receiver(key_a, bot_ssrc=1111)
        guild_b = _make_voice_receiver(key_b, bot_ssrc=2222)

        guild_a.map_ssrc(100, 42)
        guild_b.map_ssrc(200, 99)

        # Only guild A sees any packets.
        for seq in range(1, 10):
            guild_a._on_packet(
                _build_encrypted_rtp_packet(
                    key_a, b'\xf8\xff\xfe', ssrc=100, seq=seq, timestamp=960 * seq
                )
            )

        assert len(guild_b._buffers) == 0
        assert 100 in guild_a._buffers

    def test_stop_one_doesnt_affect_other(self):
        """Stopping one guild's receiver leaves the other running with data."""
        key_a = _make_secret_key()
        key_b = _make_secret_key()

        guild_a = _make_voice_receiver(key_a)
        guild_b = _make_voice_receiver(key_b)

        guild_a.map_ssrc(100, 42)
        guild_b.map_ssrc(200, 99)

        for seq in range(1, 10):
            guild_b._on_packet(
                _build_encrypted_rtp_packet(
                    key_b, b'\xf8\xff\xfe', ssrc=200, seq=seq, timestamp=960 * seq
                )
            )

        guild_a.stop()

        # Guild B is untouched: still running, buffers intact.
        assert guild_b._running is True
        assert len(guild_b._buffers[200]) > 0
|
||||
|
||||
|
||||
class TestEchoPreventionFlow:
    """Receiver pause/resume during TTS playback prevents echo."""

    @staticmethod
    def _send(receiver, key, seq_range):
        """Feed one encrypted packet (SSRC 100) per sequence number in *seq_range*."""
        for seq in seq_range:
            receiver._on_packet(
                _build_encrypted_rtp_packet(
                    key, b'\xf8\xff\xfe', ssrc=100, seq=seq, timestamp=960 * seq
                )
            )

    def test_audio_during_pause_ignored(self):
        """Packets arriving while paused never reach the buffers."""
        secret = _make_secret_key()
        receiver = _make_voice_receiver(secret)
        receiver.map_ssrc(100, 42)
        receiver.pause()

        self._send(receiver, secret, range(1, 30))

        assert len(receiver._buffers.get(100, b"")) == 0

    def test_audio_after_resume_processed(self):
        """Once resumed, the receiver buffers audio and completes on silence."""
        secret = _make_secret_key()
        receiver = _make_voice_receiver(secret)
        receiver.map_ssrc(100, 42)

        # Paused: the first burst must be dropped on the floor.
        receiver.pause()
        self._send(receiver, secret, range(1, 5))
        assert len(receiver._buffers.get(100, b"")) == 0

        # Resumed: the second burst is captured normally.
        receiver.resume()
        self._send(receiver, secret, range(5, 35))
        assert len(receiver._buffers[100]) > 0

        receiver._last_packet_time[100] = time.monotonic() - 3.0
        completed = receiver.check_silence()
        assert len(completed) == 1
        assert completed[0][0] == 42
|
||||
628
hermes_code/tests/integration/test_web_tools.py
Normal file
628
hermes_code/tests/integration/test_web_tools.py
Normal file
|
|
@ -0,0 +1,628 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Comprehensive Test Suite for Web Tools Module
|
||||
|
||||
This script tests all web tools functionality to ensure they work correctly.
|
||||
Run this after any updates to the web_tools.py module or backend libraries.
|
||||
|
||||
Usage:
|
||||
python test_web_tools.py # Run all tests
|
||||
python test_web_tools.py --no-llm # Skip LLM processing tests
|
||||
python test_web_tools.py --verbose # Show detailed output
|
||||
|
||||
Requirements:
|
||||
- PARALLEL_API_KEY or FIRECRAWL_API_KEY environment variable must be set
|
||||
- An auxiliary LLM provider (OPENROUTER_API_KEY or Nous Portal auth) (optional, for LLM tests)
|
||||
"""
|
||||
|
||||
import pytest
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
import json
|
||||
import asyncio
|
||||
import sys
|
||||
import os
|
||||
import argparse
|
||||
from datetime import datetime
|
||||
from typing import List
|
||||
|
||||
# Import the web tools to test (updated path after moving tools/)
|
||||
from tools.web_tools import (
|
||||
web_search_tool,
|
||||
web_extract_tool,
|
||||
web_crawl_tool,
|
||||
check_firecrawl_api_key,
|
||||
check_web_api_key,
|
||||
check_auxiliary_model,
|
||||
get_debug_session_info,
|
||||
_get_backend,
|
||||
)
|
||||
|
||||
|
||||
class Colors:
    """ANSI escape sequences used to colorize terminal output."""

    # Foreground colors
    HEADER = '\033[95m'
    BLUE = '\033[94m'
    CYAN = '\033[96m'
    GREEN = '\033[92m'
    WARNING = '\033[93m'
    FAIL = '\033[91m'
    # Text attributes and reset
    ENDC = '\033[0m'
    BOLD = '\033[1m'
    UNDERLINE = '\033[4m'
|
||||
|
||||
|
||||
def print_header(text: str) -> None:
    """Render *text* between two 60-char '=' rules, in bold header color."""
    rule = '=' * 60
    accent = f"{Colors.HEADER}{Colors.BOLD}"
    print(f"\n{accent}{rule}{Colors.ENDC}")
    print(f"{accent}{text}{Colors.ENDC}")
    print(f"{accent}{rule}{Colors.ENDC}")
|
||||
|
||||
|
||||
def print_section(text: str) -> None:
    """Print a cyan section banner followed by a 50-char dashed rule."""
    accent = f"{Colors.CYAN}{Colors.BOLD}"
    print(f"\n{accent}📌 {text}{Colors.ENDC}")
    print(Colors.CYAN + '-' * 50 + Colors.ENDC)
|
||||
|
||||
|
||||
def print_success(text: str) -> None:
    """Echo *text* as a green success line."""
    print(Colors.GREEN + f"✅ {text}" + Colors.ENDC)
|
||||
|
||||
|
||||
def print_error(text: str) -> None:
    """Echo *text* as a red failure line."""
    print(Colors.FAIL + f"❌ {text}" + Colors.ENDC)
|
||||
|
||||
|
||||
def print_warning(text: str) -> None:
    """Echo *text* as a yellow warning line."""
    print(Colors.WARNING + f"⚠️ {text}" + Colors.ENDC)
|
||||
|
||||
|
||||
def print_info(text: str, indent: int = 0) -> None:
    """Echo *text* as a blue info line, left-padded with *indent* spaces."""
    pad = " " * indent
    print(pad + Colors.BLUE + f"ℹ️ {text}" + Colors.ENDC)
|
||||
|
||||
|
||||
class WebToolsTester:
    """Test suite for web tools.

    Drives web_search_tool / web_extract_tool / web_crawl_tool end to end,
    accumulating passed/failed/skipped entries via log_result() and writing
    a timestamped JSON report when the run finishes.
    """

    def __init__(self, verbose: bool = False, test_llm: bool = True):
        self.verbose = verbose
        self.test_llm = test_llm
        # Each bucket holds the result dicts produced by log_result().
        self.test_results = {
            "passed": [],
            "failed": [],
            "skipped": []
        }
        self.start_time = None
        self.end_time = None

    def log_result(self, test_name: str, status: str, details: str = ""):
        """Record a single test outcome and echo it to the terminal.

        status must be "passed", "failed" or "skipped"; any other value is
        silently ignored (the entry is not stored anywhere).
        """
        result = {
            "test": test_name,
            "status": status,
            "details": details,
            "timestamp": datetime.now().isoformat()
        }

        if status == "passed":
            self.test_results["passed"].append(result)
            print_success(f"{test_name}: {details}" if details else test_name)
        elif status == "failed":
            self.test_results["failed"].append(result)
            print_error(f"{test_name}: {details}" if details else test_name)
        elif status == "skipped":
            self.test_results["skipped"].append(result)
            print_warning(f"{test_name} skipped: {details}" if details else f"{test_name} skipped")

    def test_environment(self) -> bool:
        """Check environment setup and API keys.

        Returns False when no web backend key is configured. Side effect:
        disables LLM tests (self.test_llm = False) when no auxiliary LLM
        provider is available.
        """
        print_section("Environment Check")

        # Check web backend API key (Parallel or Firecrawl)
        if not check_web_api_key():
            self.log_result("Web Backend API Key", "failed", "PARALLEL_API_KEY or FIRECRAWL_API_KEY not set")
            return False
        else:
            backend = _get_backend()
            self.log_result("Web Backend API Key", "passed", f"Using {backend} backend")

        # Check auxiliary LLM provider (optional)
        if not check_auxiliary_model():
            self.log_result("Auxiliary LLM", "skipped", "No auxiliary LLM provider available (LLM tests will be skipped)")
            self.test_llm = False
        else:
            self.log_result("Auxiliary LLM", "passed", "Found")

        # Check debug mode
        debug_info = get_debug_session_info()
        if debug_info["enabled"]:
            print_info(f"Debug mode enabled - Session: {debug_info['session_id']}")
            print_info(f"Debug log: {debug_info['log_path']}")

        return True

    def test_web_search(self) -> List[str]:
        """Test web search functionality.

        Returns up to three result URLs collected across queries, reused as
        inputs by the extraction tests.
        """
        print_section("Test 1: Web Search")

        test_queries = [
            ("Python web scraping tutorial", 5),
            ("Firecrawl API documentation", 3),
            ("inflammatory arthritis symptoms treatment", 8)  # Test medical query from your example
        ]

        extracted_urls = []

        for query, limit in test_queries:
            try:
                print(f"\n  Testing search: '{query}' (limit={limit})")

                if self.verbose:
                    print(f"  Calling web_search_tool(query='{query}', limit={limit})")

                # Perform search
                result = web_search_tool(query, limit)

                # Parse result
                try:
                    data = json.loads(result)
                except json.JSONDecodeError as e:
                    self.log_result(f"Search: {query[:30]}...", "failed", f"Invalid JSON: {e}")
                    if self.verbose:
                        print(f"  Raw response (first 500 chars): {result[:500]}...")
                    continue

                if "error" in data:
                    self.log_result(f"Search: {query[:30]}...", "failed", f"API error: {data['error']}")
                    continue

                # Check structure
                if "success" not in data or "data" not in data:
                    self.log_result(f"Search: {query[:30]}...", "failed", "Missing success or data fields")
                    if self.verbose:
                        print(f"  Response keys: {list(data.keys())}")
                    continue

                web_results = data.get("data", {}).get("web", [])

                if not web_results:
                    self.log_result(f"Search: {query[:30]}...", "failed", "Empty web results array")
                    if self.verbose:
                        print(f"  data.web content: {data.get('data', {}).get('web')}")
                    continue

                # Validate each result
                valid_results = 0
                missing_fields = []

                # Loop variable is `item` (not `result`) to avoid shadowing
                # the raw tool response held in `result` above.
                for i, item in enumerate(web_results):
                    required_fields = ["url", "title", "description"]
                    has_all_fields = all(key in item for key in required_fields)

                    if has_all_fields:
                        valid_results += 1
                        # Collect URLs for extraction test
                        if len(extracted_urls) < 3:
                            extracted_urls.append(item["url"])

                        if self.verbose:
                            print(f"  Result {i+1}: ✓ {item['title'][:50]}...")
                            print(f"    URL: {item['url'][:60]}...")
                    else:
                        missing = [f for f in required_fields if f not in item]
                        missing_fields.append(f"Result {i+1} missing: {missing}")
                        if self.verbose:
                            print(f"  Result {i+1}: ✗ Missing fields: {missing}")

                # Log results
                if valid_results == len(web_results):
                    self.log_result(
                        f"Search: {query[:30]}...",
                        "passed",
                        f"All {valid_results} results valid"
                    )
                else:
                    self.log_result(
                        f"Search: {query[:30]}...",
                        "failed",
                        f"Only {valid_results}/{len(web_results)} valid. Issues: {'; '.join(missing_fields[:3])}"
                    )

            except Exception as e:
                self.log_result(f"Search: {query[:30]}...", "failed", f"Exception: {type(e).__name__}: {str(e)}")
                if self.verbose:
                    import traceback
                    print(f"  Traceback: {traceback.format_exc()}")

        if self.verbose and extracted_urls:
            print(f"\n  URLs collected for extraction test: {len(extracted_urls)}")
            for url in extracted_urls:
                print(f"    - {url}")

        return extracted_urls

    async def test_web_extract(self, urls: List[str] = None):
        """Test web content extraction without LLM post-processing.

        urls: optional URLs collected from the search test; defaults are
        used when none are supplied.
        """
        print_section("Test 2: Web Extract (without LLM)")

        # Use provided URLs or defaults
        if not urls:
            urls = [
                "https://docs.firecrawl.dev/introduction",
                "https://www.python.org/about/"
            ]
            print(f"  Using default URLs for testing")
        else:
            print(f"  Using {len(urls)} URLs from search results")

        # Test extraction
        if urls:
            try:
                test_urls = urls[:2]  # Test with max 2 URLs
                print(f"\n  Extracting content from {len(test_urls)} URL(s)...")
                for url in test_urls:
                    print(f"    - {url}")

                if self.verbose:
                    print(f"  Calling web_extract_tool(urls={test_urls}, format='markdown', use_llm_processing=False)")

                result = await web_extract_tool(
                    test_urls,
                    format="markdown",
                    use_llm_processing=False
                )

                # Parse result
                try:
                    data = json.loads(result)
                except json.JSONDecodeError as e:
                    self.log_result("Extract (no LLM)", "failed", f"Invalid JSON: {e}")
                    if self.verbose:
                        print(f"  Raw response (first 500 chars): {result[:500]}...")
                    return

                if "error" in data:
                    self.log_result("Extract (no LLM)", "failed", f"API error: {data['error']}")
                    return

                results = data.get("results", [])

                if not results:
                    self.log_result("Extract (no LLM)", "failed", "No results in response")
                    if self.verbose:
                        print(f"  Response keys: {list(data.keys())}")
                    return

                # Validate each result
                valid_results = 0
                failed_results = 0
                total_content_length = 0
                extraction_details = []

                # Loop variable is `page` (not `result`) to avoid shadowing
                # the raw tool response held in `result` above.
                for i, page in enumerate(results):
                    title = page.get("title", "No title")
                    content = page.get("content", "")
                    error = page.get("error")

                    if error:
                        failed_results += 1
                        extraction_details.append(f"Page {i+1}: ERROR - {error}")
                        if self.verbose:
                            print(f"  Page {i+1}: ✗ Error - {error}")
                    elif content:
                        content_len = len(content)
                        total_content_length += content_len
                        valid_results += 1
                        extraction_details.append(f"Page {i+1}: {title[:40]}... ({content_len} chars)")
                        if self.verbose:
                            print(f"  Page {i+1}: ✓ {title[:50]}... - {content_len} characters")
                            print(f"    First 100 chars: {content[:100]}...")
                    else:
                        extraction_details.append(f"Page {i+1}: {title[:40]}... (EMPTY)")
                        if self.verbose:
                            print(f"  Page {i+1}: ⚠ {title[:50]}... - Empty content")

                # Log results
                if valid_results > 0:
                    self.log_result(
                        "Extract (no LLM)",
                        "passed",
                        f"{valid_results}/{len(results)} pages extracted, {total_content_length} total chars"
                    )
                else:
                    self.log_result(
                        "Extract (no LLM)",
                        "failed",
                        f"No valid content. {failed_results} errors, {len(results) - failed_results} empty"
                    )
                if self.verbose:
                    print(f"\n  Extraction details:")
                    for detail in extraction_details:
                        print(f"    {detail}")

            except Exception as e:
                self.log_result("Extract (no LLM)", "failed", f"Exception: {type(e).__name__}: {str(e)}")
                if self.verbose:
                    import traceback
                    print(f"  Traceback: {traceback.format_exc()}")

    async def test_web_extract_with_llm(self, urls: List[str] = None):
        """Test web extraction with LLM post-processing enabled.

        Uses the first collected URL when available, otherwise a default
        documentation page likely to have substantial content.
        """
        print_section("Test 3: Web Extract (with Gemini LLM)")

        if not self.test_llm:
            self.log_result("Extract (with LLM)", "skipped", "LLM testing disabled")
            return

        # Use a URL likely to have substantial content
        test_url = urls[0] if urls else "https://docs.firecrawl.dev/features/scrape"

        try:
            print(f"\n  Extracting and processing: {test_url}")

            result = await web_extract_tool(
                [test_url],
                format="markdown",
                use_llm_processing=True,
                min_length=1000  # Lower threshold for testing
            )

            data = json.loads(result)

            if "error" in data:
                self.log_result("Extract (with LLM)", "failed", data["error"])
                return

            results = data.get("results", [])

            if not results:
                self.log_result("Extract (with LLM)", "failed", "No results returned")
                return

            result = results[0]
            content = result.get("content", "")

            if content:
                content_len = len(content)

                # Check if content was actually processed (should be shorter than typical raw content)
                if content_len > 0:
                    self.log_result(
                        "Extract (with LLM)",
                        "passed",
                        f"Content processed: {content_len} chars"
                    )

                    if self.verbose:
                        print(f"\n  First 300 chars of processed content:")
                        print(f"  {content[:300]}...")
                else:
                    self.log_result("Extract (with LLM)", "failed", "No content after processing")
            else:
                self.log_result("Extract (with LLM)", "failed", "No content field in result")

        except json.JSONDecodeError as e:
            self.log_result("Extract (with LLM)", "failed", f"Invalid JSON: {e}")
        except Exception as e:
            self.log_result("Extract (with LLM)", "failed", str(e))

    async def test_web_crawl(self):
        """Test web crawling functionality against a couple of known sites.

        Each site has a minimum expected page count; fewer valid pages than
        that is treated as a failure.
        """
        print_section("Test 4: Web Crawl")

        test_sites = [
            ("https://docs.firecrawl.dev", None, 2),  # Test docs site
            ("https://firecrawl.dev", None, 3),  # Test main site
        ]

        for url, instructions, expected_min_pages in test_sites:
            try:
                print(f"\n  Testing crawl of: {url}")
                if instructions:
                    print(f"  Instructions: {instructions}")
                else:
                    print(f"  No instructions (general crawl)")
                print(f"  Expected minimum pages: {expected_min_pages}")

                # Show what's being called
                if self.verbose:
                    print(f"  Calling web_crawl_tool(url='{url}', instructions={instructions}, use_llm_processing=False)")

                result = await web_crawl_tool(
                    url,
                    instructions=instructions,
                    use_llm_processing=False  # Disable LLM for faster testing
                )

                # Check if result is valid JSON
                try:
                    data = json.loads(result)
                except json.JSONDecodeError as e:
                    self.log_result(f"Crawl: {url}", "failed", f"Invalid JSON response: {e}")
                    if self.verbose:
                        print(f"  Raw response (first 500 chars): {result[:500]}...")
                    continue

                # Check for errors
                if "error" in data:
                    self.log_result(f"Crawl: {url}", "failed", f"API error: {data['error']}")
                    continue

                # Get results
                results = data.get("results", [])

                if not results:
                    self.log_result(f"Crawl: {url}", "failed", "No pages in results array")
                    if self.verbose:
                        print(f"  Full response: {json.dumps(data, indent=2)[:1000]}...")
                    continue

                # Analyze pages
                valid_pages = 0
                empty_pages = 0
                total_content = 0
                page_details = []

                for i, page in enumerate(results):
                    content = page.get("content", "")
                    title = page.get("title", "Untitled")
                    error = page.get("error")

                    if error:
                        page_details.append(f"Page {i+1}: ERROR - {error}")
                    elif content:
                        valid_pages += 1
                        content_len = len(content)
                        total_content += content_len
                        page_details.append(f"Page {i+1}: {title[:40]}... ({content_len} chars)")
                    else:
                        empty_pages += 1
                        page_details.append(f"Page {i+1}: {title[:40]}... (EMPTY)")

                # Show detailed results if verbose
                if self.verbose:
                    print(f"\n  Crawl Results:")
                    print(f"    Total pages returned: {len(results)}")
                    print(f"    Valid pages (with content): {valid_pages}")
                    print(f"    Empty pages: {empty_pages}")
                    print(f"    Total content size: {total_content} characters")
                    print(f"\n  Page Details:")
                    for detail in page_details[:10]:  # Show first 10 pages
                        print(f"    - {detail}")
                    if len(page_details) > 10:
                        print(f"    ... and {len(page_details) - 10} more pages")

                # Determine pass/fail
                if valid_pages >= expected_min_pages:
                    self.log_result(
                        f"Crawl: {url}",
                        "passed",
                        f"{valid_pages}/{len(results)} valid pages, {total_content} chars total"
                    )
                else:
                    self.log_result(
                        f"Crawl: {url}",
                        "failed",
                        f"Only {valid_pages} valid pages (expected >= {expected_min_pages}), {empty_pages} empty, {len(results)} total"
                    )

            except Exception as e:
                self.log_result(f"Crawl: {url}", "failed", f"Exception: {type(e).__name__}: {str(e)}")
                if self.verbose:
                    import traceback
                    print(f"  Traceback:")
                    print("    " + "\n    ".join(traceback.format_exc().split("\n")))

    async def run_all_tests(self):
        """Run all tests and print a summary.

        Returns True when every executed test passed, False otherwise
        (including when required API keys are missing).
        """
        self.start_time = datetime.now()

        print_header("WEB TOOLS TEST SUITE")
        print(f"Started at: {self.start_time.strftime('%Y-%m-%d %H:%M:%S')}")

        # Test environment
        if not self.test_environment():
            print_error("\nCannot proceed without required API keys!")
            return False

        # Test search and collect URLs
        urls = self.test_web_search()

        # Test extraction
        await self.test_web_extract(urls if urls else None)

        # Test extraction with LLM
        if self.test_llm:
            await self.test_web_extract_with_llm(urls if urls else None)

        # Test crawling
        await self.test_web_crawl()

        # Print summary
        self.end_time = datetime.now()
        duration = (self.end_time - self.start_time).total_seconds()

        print_header("TEST SUMMARY")
        print(f"Duration: {duration:.2f} seconds")
        print(f"\n{Colors.GREEN}Passed: {len(self.test_results['passed'])}{Colors.ENDC}")
        print(f"{Colors.FAIL}Failed: {len(self.test_results['failed'])}{Colors.ENDC}")
        print(f"{Colors.WARNING}Skipped: {len(self.test_results['skipped'])}{Colors.ENDC}")

        # List failed tests
        if self.test_results["failed"]:
            print(f"\n{Colors.FAIL}{Colors.BOLD}Failed Tests:{Colors.ENDC}")
            for test in self.test_results["failed"]:
                print(f"  - {test['test']}: {test['details']}")

        # Save results to file
        self.save_results()

        return len(self.test_results["failed"]) == 0

    def save_results(self):
        """Save test results to a timestamped JSON file in the current directory."""
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"test_results_web_tools_{timestamp}.json"

        results = {
            "test_suite": "Web Tools",
            "start_time": self.start_time.isoformat() if self.start_time else None,
            "end_time": self.end_time.isoformat() if self.end_time else None,
            "duration_seconds": (self.end_time - self.start_time).total_seconds() if self.start_time and self.end_time else None,
            "summary": {
                "passed": len(self.test_results["passed"]),
                "failed": len(self.test_results["failed"]),
                "skipped": len(self.test_results["skipped"])
            },
            "results": self.test_results,
            "environment": {
                "web_backend": _get_backend() if check_web_api_key() else None,
                "firecrawl_api_key": check_firecrawl_api_key(),
                "parallel_api_key": bool(os.getenv("PARALLEL_API_KEY")),
                "auxiliary_model": check_auxiliary_model(),
                "debug_mode": get_debug_session_info()["enabled"]
            }
        }

        try:
            with open(filename, 'w') as f:
                json.dump(results, f, indent=2, ensure_ascii=False)
            # BUG FIX: previously printed the literal "(unknown)" instead of
            # the actual output path.
            print_info(f"Test results saved to: {filename}")
        except Exception as e:
            print_warning(f"Failed to save results: {e}")
|
||||
|
||||
|
||||
async def main():
    """CLI entry point: parse flags, run the suite, exit 0 on success."""
    parser = argparse.ArgumentParser(description="Test Web Tools Module")
    parser.add_argument("--no-llm", action="store_true", help="Skip LLM processing tests")
    parser.add_argument("--verbose", "-v", action="store_true", help="Show detailed output")
    parser.add_argument("--debug", action="store_true", help="Enable debug mode for web tools")
    args = parser.parse_args()

    if args.debug:
        # The web tools read this environment variable for debug logging.
        os.environ["WEB_TOOLS_DEBUG"] = "true"
        print_info("Debug mode enabled for web tools")

    suite = WebToolsTester(
        verbose=args.verbose,
        test_llm=not args.no_llm
    )
    ok = await suite.run_all_tests()

    # Non-zero exit signals failure to the calling shell/CI.
    sys.exit(0 if ok else 1)


if __name__ == "__main__":
    asyncio.run(main())
|
||||
Loading…
Add table
Add a link
Reference in a new issue