The architecture has been updated

This commit is contained in:
Skyber_2 2026-03-31 23:31:36 +03:00
parent 805f7a017e
commit a01257ead9
1119 changed files with 226 additions and 352 deletions

View file

@ -0,0 +1,132 @@
#!/usr/bin/env python3
"""
Test script for batch runner
This script tests the batch runner with a small sample dataset
to verify functionality before running large batches.
"""
import pytest
pytestmark = pytest.mark.integration
import json
import shutil
from pathlib import Path
def create_test_dataset():
    """Create a small JSONL dataset for a smoke run of the batch runner.

    Writes three ``{"prompt": ...}`` records, one JSON object per line, to
    ``tests/test_dataset.jsonl`` relative to the current working directory.

    Returns:
        Path of the dataset file that was written.
    """
    test_file = Path("tests/test_dataset.jsonl")
    # parents=True keeps this working if the target is ever nested deeper.
    test_file.parent.mkdir(parents=True, exist_ok=True)
    prompts = [
        {"prompt": "What is 2 + 2?"},
        {"prompt": "What is the capital of France?"},
        {"prompt": "Explain what Python is in one sentence."},
    ]
    # ensure_ascii=False emits raw non-ASCII characters, so pin the encoding
    # instead of relying on the platform's locale default.
    with open(test_file, 'w', encoding='utf-8') as f:
        f.writelines(
            json.dumps(prompt, ensure_ascii=False) + "\n" for prompt in prompts
        )
    print(f"✅ Created test dataset: {test_file}")
    return test_file
def cleanup_test_run(run_name):
    """Delete ``data/<run_name>`` left behind by an earlier run, if it exists."""
    run_dir = Path("data") / run_name
    if not run_dir.exists():
        # Nothing to remove (and nothing to report).
        return
    shutil.rmtree(run_dir)
    print(f"🗑️ Cleaned up test output: {run_dir}")
def verify_output(run_name):
    """Check that a finished run left the expected artifacts in ``data/<run_name>``.

    Looks for the run directory, ``checkpoint.json``, ``statistics.json`` and at
    least one ``batch_*.jsonl`` file, then prints a summary read from the
    statistics file.

    Returns:
        True when every artifact is present; False as soon as one is missing.
    """
    output_dir = Path("data") / run_name
    if not output_dir.exists():
        print(f"❌ Output directory not found: {output_dir}")
        return False

    checkpoint_file = output_dir / "checkpoint.json"
    stats_file = output_dir / "statistics.json"
    # Checked in this order so the first missing file is the one reported.
    required_files = (("Checkpoint", checkpoint_file), ("Statistics", stats_file))
    for label, required in required_files:
        if not required.exists():
            print(f"❌ {label} file not found: {required}")
            return False

    batch_files = list(output_dir.glob("batch_*.jsonl"))
    if not batch_files:
        print(f"❌ No batch files found in: {output_dir}")
        return False

    print("✅ Output verification passed:")
    print(f" - Checkpoint: {checkpoint_file}")
    print(f" - Statistics: {stats_file}")
    print(f" - Batch files: {len(batch_files)}")

    # Summarise the statistics the run recorded.
    with open(stats_file) as f:
        stats = json.load(f)
    print("\n📊 Statistics Summary:")
    print(f" - Total prompts: {stats['total_prompts']}")
    print(f" - Total batches: {stats['total_batches']}")
    print(f" - Duration: {stats['duration_seconds']}s")

    tool_statistics = stats.get('tool_statistics')
    if tool_statistics:
        print(" - Tool calls:")
        for tool, tool_stats in tool_statistics.items():
            print(f"{tool}: {tool_stats['count']} calls, {tool_stats['success_rate']:.1f}% success")
    return True
def main():
    """Prepare a fresh test dataset and print how to run the batch runner on it.

    Removes any previous ``test_run`` output, writes the sample dataset, then
    prints the manual command lines. The batch runner itself is not invoked.
    """
    print("🧪 Batch Runner Test")
    print("=" * 60)
    run_name = "test_run"
    # Start from a clean slate, then (re)create the input file.
    cleanup_test_run(run_name)
    test_file = create_test_dataset()

    instructions = [
        "\n📝 To run the test manually:",
        " python batch_runner.py \\",
        f" --dataset_file={test_file} \\",
        " --batch_size=2 \\",
        f" --run_name={run_name} \\",
        " --distribution=minimal \\",
        " --num_workers=2",
        "\n💡 Or test with different distributions:",
        " python batch_runner.py --list_distributions",
        "\n🔍 After running, you can verify output with:",
        " python tests/test_batch_runner.py --verify",
    ]
    for line in instructions:
        print(line)
    # Note: the batch runner is deliberately not executed here, to avoid API
    # calls during testing. Users run it manually with their API keys configured.
if __name__ == "__main__":
    import sys

    # `--verify` inspects the output of a finished run; anything else prints
    # the setup instructions.
    if "--verify" not in sys.argv:
        main()
    else:
        run_name = "test_run"
        verify_output(run_name)

View file

@ -0,0 +1,440 @@
#!/usr/bin/env python3
"""
Test script to verify checkpoint behavior in batch_runner.py
This script simulates batch processing with intentional failures to test:
1. Whether checkpoints are saved incrementally during processing
2. Whether resume functionality works correctly after interruption
3. Whether data integrity is maintained across checkpoint cycles
Usage:
# Test current implementation
python tests/test_checkpoint_resumption.py --test_current
# Test after fix is applied (re-run the same check; there is no separate flag)
python tests/test_checkpoint_resumption.py --test_current
# Run full comparison
python tests/test_checkpoint_resumption.py --compare
"""
import pytest
pytestmark = pytest.mark.integration
import json
import os
import shutil
import sys
import time
from pathlib import Path
from typing import List, Dict, Any
import traceback
# Add project root to path to import batch_runner
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
def create_test_dataset(num_prompts: int = 20) -> Path:
    """Write a throwaway JSONL dataset used by the checkpoint tests.

    Args:
        num_prompts: Number of records to generate.

    Returns:
        Path to ``tests/test_data/checkpoint_test_dataset.jsonl``.
    """
    target_dir = Path("tests/test_data")
    target_dir.mkdir(parents=True, exist_ok=True)
    dataset_file = target_dir / "checkpoint_test_dataset.jsonl"

    # One record per index; test_id mirrors the position in the file.
    records = (
        {
            "prompt": f"Test prompt {idx}: What is 2+2? Just answer briefly.",
            "test_id": idx,
        }
        for idx in range(num_prompts)
    )
    with open(dataset_file, 'w', encoding='utf-8') as out:
        for record in records:
            out.write(json.dumps(record, ensure_ascii=False) + "\n")

    print(f"✅ Created test dataset: {dataset_file} ({num_prompts} prompts)")
    return dataset_file
def monitor_checkpoint_during_run(checkpoint_file: Path, duration: int = 30) -> List[Dict[str, Any]]:
    """
    Poll a checkpoint file and record a snapshot each time its mtime changes.

    Args:
        checkpoint_file: Path to checkpoint file to monitor
        duration: How long to monitor (seconds)

    Returns:
        List of checkpoint snapshots, in observation order. Each snapshot has
        ``elapsed_seconds``, ``completed_count``, ``completed_prompts`` (first
        five entries only) and ``timestamp`` (the file's ``last_updated`` value).
    """
    snapshots: List[Dict[str, Any]] = []
    start_time = time.time()
    last_mtime = None
    poll_interval = 0.5  # seconds between checks

    print(f"\n🔍 Monitoring checkpoint file: {checkpoint_file}")
    print(f" Duration: {duration}s")
    print("-" * 70)

    while time.time() - start_time < duration:
        elapsed = time.time() - start_time
        # EAFP: a separate exists() check followed by stat() can race with the
        # writer replacing or removing the file between the two calls.
        try:
            current_mtime = checkpoint_file.stat().st_mtime
        except (FileNotFoundError, NotADirectoryError):
            current_mtime = None

        if current_mtime is None:
            if not snapshots:
                print(f"[{elapsed:6.2f}s] Checkpoint file not yet created...")
        elif current_mtime != last_mtime:
            try:
                with open(checkpoint_file, 'r', encoding='utf-8') as f:
                    checkpoint_data = json.load(f)
                if not isinstance(checkpoint_data, dict):
                    raise ValueError("checkpoint is not a JSON object")
                completed = checkpoint_data.get("completed_prompts", [])
                snapshot = {
                    "elapsed_seconds": round(elapsed, 2),
                    "completed_count": len(completed),
                    "completed_prompts": completed[:5],  # First 5 for display
                    "timestamp": checkpoint_data.get("last_updated"),
                }
                snapshots.append(snapshot)
                print(f"[{elapsed:6.2f}s] Checkpoint updated: {snapshot['completed_count']} prompts completed")
            except Exception as e:
                # Best effort: a half-written file is reported and skipped; the
                # next mtime change triggers another read.
                print(f"[{elapsed:6.2f}s] Error reading checkpoint: {e}")
            last_mtime = current_mtime

        # Never sleep past the requested monitoring window.
        remaining = duration - (time.time() - start_time)
        if remaining > 0:
            time.sleep(min(poll_interval, remaining))
    return snapshots
def _cleanup_test_artifacts(*paths):
"""Remove test-generated files and directories."""
for p in paths:
p = Path(p)
if p.is_dir():
shutil.rmtree(p, ignore_errors=True)
elif p.is_file():
p.unlink(missing_ok=True)
def test_current_implementation():
    """Run a short batch job while watching the checkpoint file for updates.

    A background thread polls ``data/checkpoint_test_current/checkpoint.json``
    while ``BatchRunner.run`` executes. More than one observed update means the
    checkpoint is written incrementally.

    Returns:
        True if several checkpoint updates were observed, False otherwise
        (including when the run itself raised).
    """
    print("\n" + "=" * 70)
    print("TEST 1: Current Implementation - Checkpoint Timing")
    print("=" * 70)
    print("\n📝 Testing whether checkpoints are saved incrementally during run...")

    # Setup
    dataset_file = create_test_dataset(num_prompts=12)
    run_name = "checkpoint_test_current"
    output_dir = Path("data") / run_name
    # Clean up any existing test data
    if output_dir.exists():
        shutil.rmtree(output_dir)

    # Import here to avoid issues if module changes
    from batch_runner import BatchRunner
    checkpoint_file = output_dir / "checkpoint.json"

    print(f"\n▶️ Starting batch run...")
    print(f" Dataset: {dataset_file}")
    print(f" Batch size: 3 (4 batches total)")
    print(f" Workers: 2")
    print(f" Expected behavior: If incremental, checkpoint should update during run")

    # Length of the monitoring window. Updates written after it closes are not
    # observed, so it must cover the expected run time.
    monitor_seconds = 60
    snapshots = []
    start_time = time.time()
    try:
        runner = BatchRunner(
            dataset_file=str(dataset_file),
            batch_size=3,
            run_name=run_name,
            distribution="default",
            max_iterations=3,  # Keep it short
            model="claude-opus-4-20250514",
            num_workers=2,
            verbose=False
        )

        # Run with monitoring
        import threading

        def monitor():
            # The monitor only returns its snapshots when its window ends, so
            # they are copied into the shared list at that point.
            snapshots.extend(
                monitor_checkpoint_during_run(checkpoint_file, duration=monitor_seconds)
            )

        monitor_thread = threading.Thread(target=monitor, daemon=True)
        monitor_thread.start()
        runner.run(resume=False)
        # Measure the run itself, not the wait for the monitor below.
        elapsed = time.time() - start_time
        # Wait for the monitoring window to close. Joining with a short timeout
        # would read `snapshots` before the monitor has delivered anything and
        # report zero updates for any run shorter than the window.
        monitor_thread.join(timeout=monitor_seconds + 5)
    except Exception as e:
        print(f"❌ Error during run: {e}")
        traceback.print_exc()
        return False
    finally:
        _cleanup_test_artifacts(dataset_file, output_dir)

    # Analyze results
    print("\n" + "=" * 70)
    print("📊 TEST RESULTS")
    print("=" * 70)
    print(f"Total run time: {elapsed:.2f}s")
    print(f"Checkpoint updates observed: {len(snapshots)}")

    if not snapshots:
        print("\n❌ ISSUE: No checkpoint updates observed during run")
        print(" This suggests checkpoints are only saved at the end")
        return False
    elif len(snapshots) == 1:
        print("\n⚠️ WARNING: Only 1 checkpoint update (likely at the end)")
        print(" This confirms the bug - no incremental checkpointing")
        return False
    else:
        print(f"\n✅ GOOD: Multiple checkpoint updates ({len(snapshots)}) observed")
        print(" Checkpointing appears to be incremental")
        # Show timeline
        print("\n📈 Checkpoint Timeline:")
        for i, snapshot in enumerate(snapshots, 1):
            print(f" {i}. [{snapshot['elapsed_seconds']:6.2f}s] "
                  f"{snapshot['completed_count']} prompts completed")
        return True
def test_interruption_and_resume():
    """Check that a resumed run finishes the prompts an earlier run left undone.

    The first run processes a dataset holding only the first 5 prompts. The
    second run resumes under the same run name with the full 15-prompt dataset.

    Returns:
        True if the final checkpoint lists 15 completed prompts, otherwise False.
    """
    rule = "=" * 70
    print("\n" + rule)
    print("TEST 2: Interruption and Resume")
    print(rule)
    print("\n📝 Testing whether resume works after manual interruption...")

    # Setup
    dataset_file = create_test_dataset(num_prompts=15)
    run_name = "checkpoint_test_resume"
    output_dir = Path("data") / run_name
    # Drop leftovers from an earlier attempt.
    if output_dir.exists():
        shutil.rmtree(output_dir)

    from batch_runner import BatchRunner
    checkpoint_file = output_dir / "checkpoint.json"

    def build_runner(source):
        # Both runs share every setting except the dataset they read.
        return BatchRunner(
            dataset_file=str(source),
            batch_size=2,
            run_name=run_name,
            distribution="default",
            max_iterations=3,
            model="claude-opus-4-20250514",
            num_workers=1,
            verbose=False
        )

    def count_completed():
        # Number of prompts the checkpoint currently records as done.
        with open(checkpoint_file, 'r') as handle:
            return len(json.load(handle).get("completed_prompts", []))

    print(f"\n▶️ Starting first run (will process 5 prompts, then simulate interruption)...")
    temp_dataset = Path("tests/test_data/checkpoint_test_resume_partial.jsonl")
    try:
        # The "interrupted" run only ever sees the first 5 prompts.
        with open(dataset_file, 'r') as src:
            head = src.readlines()[:5]
        with open(temp_dataset, 'w') as dst:
            dst.writelines(head)

        build_runner(temp_dataset).run(resume=False)

        # Check checkpoint after first run
        if not checkpoint_file.exists():
            print("❌ ERROR: Checkpoint file not created after first run")
            return False
        initial_completed = count_completed()
        print(f"✅ First run completed: {initial_completed} prompts saved to checkpoint")

        # Now resume against the full dataset
        print(f"\n▶️ Starting resume run with full dataset (15 prompts)...")
        build_runner(dataset_file).run(resume=True)

        final_completed = count_completed()
        print("\n" + rule)
        print("📊 TEST RESULTS")
        print(rule)
        print(f"Initial completed: {initial_completed}")
        print(f"Final completed: {final_completed}")
        print(f"Expected: 15")

        if final_completed != 15:
            print(f"\n❌ FAIL: Expected 15 completed, got {final_completed}")
            return False
        print("\n✅ PASS: Resume successfully completed all prompts")
        return True
    except Exception as e:
        print(f"❌ Error during test: {e}")
        traceback.print_exc()
        return False
    finally:
        _cleanup_test_artifacts(dataset_file, temp_dataset, output_dir)
def test_simulated_crash():
    """Placeholder for the crash scenario: prints a notice and returns None."""
    notice = (
        "\n" + "=" * 70,
        "TEST 3: Simulated Crash During Execution",
        "=" * 70,
        "\n📝 This test would require running in a subprocess and killing it...",
        " Skipping for safety - manual testing recommended",
    )
    for line in notice:
        print(line)
    # None is reported as "skipped" by the summary in main().
    return None
def print_test_plan():
    """Print the write-up of the checkpoint problem, the proposed fix and the test plan."""
    rule = "=" * 70
    plan_text = """
📋 PROBLEM SUMMARY
------------------
Current implementation uses pool.map() which blocks until ALL batches complete.
Checkpoint is only saved after all batches finish (line 558-559).
If process crashes during batch processing:
- All progress is lost
- Resume does nothing (no incremental checkpoint was saved)
📋 PROPOSED SOLUTION
--------------------
Replace pool.map() with pool.imap_unordered() to get results as they complete.
Save checkpoint after EACH batch completes using a multiprocessing Lock.
Key changes:
1. Use Manager().Lock() for thread-safe checkpoint writes
2. Replace pool.map() with pool.imap_unordered()
3. Update checkpoint after each batch result
4. Maintain backward compatibility with existing checkpoints
📋 IMPLEMENTATION STEPS
-----------------------
1. Add Manager and Lock initialization before Pool creation
2. Pass shared checkpoint data and lock to workers (via Manager)
3. Replace pool.map() with pool.imap_unordered()
4. In result loop: save checkpoint after each batch
5. Add error handling for checkpoint write failures
📋 RISKS & MITIGATIONS
----------------------
Risk: Checkpoint file corruption if two processes write simultaneously
Mitigation: Use multiprocessing.Lock() for exclusive access
Risk: Performance impact from frequent checkpoint writes
Mitigation: Checkpoint writes are fast (small JSON), negligible impact
Risk: Breaking existing runs that are already checkpointed
Mitigation: Maintain checkpoint format, only change timing
Risk: Bugs in multiprocessing lock/manager code
Mitigation: Thorough testing with this test script
📋 TESTING STRATEGY
-------------------
1. Run test_current_implementation() - Confirm bug exists
2. Apply fix to batch_runner.py
3. Run test_current_implementation() again - Should see incremental updates
4. Run test_interruption_and_resume() - Verify resume works
5. Manual test: Start run, kill process mid-batch, resume
📋 ROLLBACK PLAN
----------------
If issues arise:
1. Git revert the changes
2. Original code is working (just missing incremental checkpoint)
3. No data corruption risk - checkpoints are write-only
"""
    # Banner first, then the plan body in a single print.
    print("\n" + rule)
    print("CHECKPOINT FIX - DETAILED PLAN")
    print(rule)
    print(plan_text)
def main(
    test_current: bool = False,
    test_resume: bool = False,
    test_crash: bool = False,
    compare: bool = False,
    show_plan: bool = False
):
    """
    Run the selected checkpoint behavior tests and print a summary.

    With no test selected, or with ``show_plan`` set, only the fix plan is
    printed.

    Args:
        test_current: Test current implementation checkpoint timing
        test_resume: Test interruption and resume functionality
        test_crash: Test simulated crash scenario (manual)
        compare: Run all tests and compare
        show_plan: Show detailed fix plan
    """
    nothing_selected = not (test_current or test_resume or test_crash or compare)
    if show_plan or nothing_selected:
        print_test_plan()
        return

    results = {}
    # `compare` switches every test on; otherwise each flag selects its own.
    if compare or test_current:
        results['current'] = test_current_implementation()
    if compare or test_resume:
        results['resume'] = test_interruption_and_resume()
    if compare or test_crash:
        results['crash'] = test_simulated_crash()

    if not results:
        return

    # Summary
    print("\n" + "=" * 70)
    print("OVERALL TEST SUMMARY")
    print("=" * 70)
    for test_name, result in results.items():
        if result is None:
            # A test that returns None did not actually run.
            status = "⏭️ SKIPPED"
        else:
            status = "✅ PASS" if result else "❌ FAIL"
        print(f"{status} - {test_name}")
# Script entry point: python-fire exposes main()'s keyword arguments as
# command-line flags (e.g. --test_current, --compare, --show_plan).
if __name__ == "__main__":
    # Imported here, so `fire` is only needed when the file is run as a script.
    import fire
    fire.Fire(main)

View file

@ -0,0 +1,123 @@
"""Integration tests for the Daytona terminal backend.
Requires DAYTONA_API_KEY to be set. Run with:
TERMINAL_ENV=daytona pytest tests/integration/test_daytona_terminal.py -v
"""
import json
import os
import sys
from pathlib import Path
import pytest
pytestmark = pytest.mark.integration
# Skip entire module if no API key
if not os.getenv("DAYTONA_API_KEY"):
    pytest.skip("DAYTONA_API_KEY not set", allow_module_level=True)
# Import terminal_tool via importlib to avoid tools/__init__.py side effects
import importlib.util
# Directory three levels above this file (presumably the project root — confirm).
parent_dir = Path(__file__).parent.parent.parent
sys.path.insert(0, str(parent_dir))
# Load tools/terminal_tool.py directly from its file path under the module
# name "terminal_tool".
spec = importlib.util.spec_from_file_location(
    "terminal_tool", parent_dir / "tools" / "terminal_tool.py"
)
terminal_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(terminal_module)
# Module-level aliases for the two callables the tests below use.
terminal_tool = terminal_module.terminal_tool
cleanup_vm = terminal_module.cleanup_vm
@pytest.fixture(autouse=True)
def _force_daytona(monkeypatch):
    """Set the terminal environment variables to Daytona values for each test."""
    overrides = {
        "TERMINAL_ENV": "daytona",
        "TERMINAL_CONTAINER_DISK": "10240",
        "TERMINAL_CONTAINER_PERSISTENT": "false",
    }
    # monkeypatch restores the previous values when the test finishes.
    for name, value in overrides.items():
        monkeypatch.setenv(name, value)
@pytest.fixture()
def task_id(request):
    """Yield a task id derived from the test's name, then clean up its sandbox."""
    sandbox_id = "daytona_test_" + request.node.name
    yield sandbox_id
    # Teardown: runs after the test body, whether it passed or failed.
    cleanup_vm(sandbox_id)
def _run(command, task_id, **kwargs):
    """Run *command* via terminal_tool for *task_id* and return the decoded JSON result."""
    raw = terminal_tool(command, task_id=task_id, **kwargs)
    payload = json.loads(raw)
    return payload
class TestDaytonaBasic:
    """Basic command execution: exit codes and captured output."""

    def test_echo(self, task_id):
        """A plain echo succeeds and its text appears in the output."""
        result = _run("echo 'Hello from Daytona!'", task_id)
        assert result["exit_code"] == 0
        assert "Hello from Daytona!" in result["output"]

    def test_python_version(self, task_id):
        """python3 is available and reports its version."""
        result = _run("python3 --version", task_id)
        assert result["exit_code"] == 0
        assert "Python" in result["output"]

    def test_nonzero_exit(self, task_id):
        """A non-zero exit status is reported unchanged."""
        result = _run("exit 42", task_id)
        assert result["exit_code"] == 42

    def test_os_info(self, task_id):
        """uname reports a Linux system."""
        result = _run("uname -a", task_id)
        assert result["exit_code"] == 0
        assert "Linux" in result["output"]
class TestDaytonaFilesystem:
    """State written by one command is visible to later commands of the same task."""

    def test_write_and_read_file(self, task_id):
        """A file written by one command can be read back by the next."""
        _run("echo 'test content' > /tmp/daytona_test.txt", task_id)
        readback = _run("cat /tmp/daytona_test.txt", task_id)
        assert readback["exit_code"] == 0
        assert "test content" in readback["output"]

    def test_persistence_within_session(self, task_id):
        """A package installed by one command is importable in the next."""
        _run("pip install cowsay 2>/dev/null", task_id, timeout=120)
        imported = _run('python3 -c "import cowsay; print(cowsay.__file__)"', task_id)
        assert imported["exit_code"] == 0
        assert "cowsay" in imported["output"]
class TestDaytonaPersistence:
    """Filesystem persistence across a sandbox stop/resume cycle."""

    def test_filesystem_survives_stop_and_resume(self):
        """Write a file, stop the sandbox, resume it, assert the file persists."""
        sandbox = "daytona_test_persist"
        persist_flag = "TERMINAL_CONTAINER_PERSISTENT"
        try:
            # Turn persistence on for this test only.
            os.environ[persist_flag] = "true"
            # Leave a marker behind, then release the sandbox. With persistence
            # on, this is expected to stop the sandbox rather than delete it.
            _run("echo 'survive' > /tmp/persist_test.txt", sandbox)
            cleanup_vm(sandbox)
            # Same task id again: the marker should still be there.
            resumed = _run("cat /tmp/persist_test.txt", sandbox)
            assert resumed["exit_code"] == 0
            assert "survive" in resumed["output"]
        finally:
            # Persistence off before the final cleanup so the sandbox doesn't leak.
            os.environ[persist_flag] = "false"
            cleanup_vm(sandbox)
class TestDaytonaIsolation:
    """Sandboxes belonging to different task ids do not share a filesystem."""

    def test_different_tasks_isolated(self):
        """A file written under one task id is not readable under another."""
        writer_task = "daytona_test_iso_a"
        reader_task = "daytona_test_iso_b"
        try:
            _run("echo 'secret' > /tmp/isolated.txt", writer_task)
            probe = _run("cat /tmp/isolated.txt 2>&1 || echo NOT_FOUND", reader_task)
            output = probe["output"]
            leaked = "secret" in output
            reported_missing = "NOT_FOUND" in output
            assert (not leaked) or reported_missing
        finally:
            cleanup_vm(writer_task)
            cleanup_vm(reader_task)

View file

@ -0,0 +1,341 @@
"""Integration tests for Home Assistant (tool + gateway).
Spins up a real in-process fake HA server (HTTP + WebSocket) and exercises
the full adapter and tool handler paths over real TCP connections.
No mocks -- only real async I/O against a fake server.
Run with: uv run pytest tests/integration/test_ha_integration.py -v
"""
import asyncio
import pytest
pytestmark = pytest.mark.integration
from unittest.mock import AsyncMock
from gateway.config import Platform, PlatformConfig
from gateway.platforms.homeassistant import HomeAssistantAdapter
from tests.fakes.fake_ha_server import FakeHAServer, ENTITY_STATES
from tools.homeassistant_tool import (
_async_call_service,
_async_get_state,
_async_list_entities,
)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _adapter_for(server: FakeHAServer, **extra) -> HomeAssistantAdapter:
    """Build a HomeAssistantAdapter configured with *server*'s URL and token.

    Keyword arguments are merged into the config's ``extra`` mapping and take
    precedence over the default ``url`` entry.
    """
    settings = {"url": server.url}
    settings.update(extra)
    return HomeAssistantAdapter(
        PlatformConfig(enabled=True, token=server.token, extra=settings)
    )
# ---------------------------------------------------------------------------
# 1. Gateway -- WebSocket lifecycle
# ---------------------------------------------------------------------------
class TestGatewayWebSocket:
    """WebSocket lifecycle of the gateway adapter against the fake HA server.

    Tests that leave the adapter connected disconnect it in a ``finally``
    block, so a failing assertion does not leave the adapter connected.
    """

    @pytest.mark.asyncio
    async def test_connect_auth_subscribe(self):
        """Full WS handshake succeeds: auth_required -> auth -> auth_ok -> subscribe -> ACK."""
        async with FakeHAServer() as server:
            adapter = _adapter_for(server)
            connected = await adapter.connect()
            assert connected is True
            try:
                assert adapter._running is True
                assert adapter._ws is not None
                assert not adapter._ws.closed
            finally:
                await adapter.disconnect()

    @pytest.mark.asyncio
    async def test_connect_auth_rejected(self):
        """connect() returns False when the server rejects auth."""
        async with FakeHAServer() as server:
            server.reject_auth = True
            adapter = _adapter_for(server)
            connected = await adapter.connect()
            assert connected is False

    @pytest.mark.asyncio
    async def test_event_received_and_forwarded(self):
        """Server pushes event -> adapter calls handle_message with correct MessageEvent."""
        async with FakeHAServer() as server:
            adapter = _adapter_for(server)
            adapter.handle_message = AsyncMock()
            await adapter.connect()
            try:
                # Push a state_changed event
                await server.push_event({
                    "data": {
                        "entity_id": "light.bedroom",
                        "old_state": {"state": "off", "attributes": {}},
                        "new_state": {
                            "state": "on",
                            "attributes": {"friendly_name": "Bedroom Light"},
                        },
                    }
                })
                # Poll (up to ~2.5 s) for the adapter to process it
                for _ in range(50):
                    if adapter.handle_message.call_count > 0:
                        break
                    await asyncio.sleep(0.05)
                assert adapter.handle_message.call_count == 1
                msg_event = adapter.handle_message.call_args[0][0]
                assert "Bedroom Light" in msg_event.text
                assert "turned on" in msg_event.text
                assert msg_event.source.platform == Platform.HOMEASSISTANT
            finally:
                await adapter.disconnect()

    @pytest.mark.asyncio
    async def test_event_filtering_ignores_unwatched(self):
        """Events outside watch_domains are silently dropped."""
        async with FakeHAServer() as server:
            adapter = _adapter_for(server, watch_domains=["climate"])
            adapter.handle_message = AsyncMock()
            await adapter.connect()
            try:
                # Push a light event (not in watch_domains)
                await server.push_event({
                    "data": {
                        "entity_id": "light.bedroom",
                        "old_state": {"state": "off", "attributes": {}},
                        "new_state": {
                            "state": "on",
                            "attributes": {"friendly_name": "Bedroom Light"},
                        },
                    }
                })
                # Give the adapter time to (not) forward the event.
                await asyncio.sleep(0.5)
                assert adapter.handle_message.call_count == 0
            finally:
                await adapter.disconnect()

    @pytest.mark.asyncio
    async def test_disconnect_closes_cleanly(self):
        """disconnect() cancels listener and closes WebSocket."""
        async with FakeHAServer() as server:
            adapter = _adapter_for(server)
            await adapter.connect()
            # Keep a reference: disconnect() is expected to clear adapter._ws.
            ws_ref = adapter._ws
            await adapter.disconnect()
            assert adapter._running is False
            assert adapter._listen_task is None
            assert adapter._ws is None
            # The original WS reference should be closed
            assert ws_ref.closed
# ---------------------------------------------------------------------------
# 2. REST tool handlers (real HTTP against fake server)
# ---------------------------------------------------------------------------
class TestToolRest:
    """Exercise the async tool functions over real HTTP against the fake server.

    The ``_async_*`` functions are called directly rather than the sync
    ``_handle_*`` wrappers: those wrappers go through ``_run_async``, which
    blocks the event loop and would deadlock with the in-process fake server.
    The wrappers are thin bridge code covered by unit tests.
    """

    @staticmethod
    def _point_tool_at(monkeypatch, server):
        # Redirect the tool module's URL/token globals to the fake server.
        monkeypatch.setattr("tools.homeassistant_tool._HASS_URL", server.url)
        monkeypatch.setattr("tools.homeassistant_tool._HASS_TOKEN", server.token)

    @pytest.mark.asyncio
    async def test_list_entities_returns_all(self, monkeypatch):
        """_async_list_entities returns all entities from the fake server."""
        async with FakeHAServer() as server:
            self._point_tool_at(monkeypatch, server)
            listing = await _async_list_entities()
            assert listing["count"] == len(ENTITY_STATES)
            entity_ids = {entity["entity_id"] for entity in listing["entities"]}
            assert "light.bedroom" in entity_ids
            assert "climate.thermostat" in entity_ids

    @pytest.mark.asyncio
    async def test_list_entities_domain_filter(self, monkeypatch):
        """Only entities of the requested domain are returned."""
        async with FakeHAServer() as server:
            self._point_tool_at(monkeypatch, server)
            listing = await _async_list_entities(domain="light")
            assert listing["count"] == 2
            for entity in listing["entities"]:
                assert entity["entity_id"].startswith("light.")

    @pytest.mark.asyncio
    async def test_get_state_single_entity(self, monkeypatch):
        """_async_get_state returns full entity details."""
        async with FakeHAServer() as server:
            self._point_tool_at(monkeypatch, server)
            state = await _async_get_state("light.bedroom")
            assert state["entity_id"] == "light.bedroom"
            assert state["state"] == "on"
            assert state["attributes"]["brightness"] == 200
            assert state["last_changed"] is not None

    @pytest.mark.asyncio
    async def test_get_state_not_found(self, monkeypatch):
        """Non-existent entity raises an aiohttp error (404)."""
        import aiohttp as _aiohttp
        async with FakeHAServer() as server:
            self._point_tool_at(monkeypatch, server)
            with pytest.raises(_aiohttp.ClientResponseError) as exc_info:
                await _async_get_state("light.nonexistent")
            assert exc_info.value.status == 404

    @pytest.mark.asyncio
    async def test_call_service_turn_on(self, monkeypatch):
        """_async_call_service sends correct payload and server records it."""
        async with FakeHAServer() as server:
            self._point_tool_at(monkeypatch, server)
            outcome = await _async_call_service(
                domain="light",
                service="turn_on",
                entity_id="light.bedroom",
                data={"brightness": 255},
            )
            assert outcome["success"] is True
            assert outcome["service"] == "light.turn_on"
            assert len(outcome["affected_entities"]) == 1
            assert outcome["affected_entities"][0]["state"] == "on"
            # The fake server should have recorded exactly this call.
            assert len(server.received_service_calls) == 1
            recorded = server.received_service_calls[0]
            assert recorded["domain"] == "light"
            assert recorded["service"] == "turn_on"
            assert recorded["data"]["entity_id"] == "light.bedroom"
            assert recorded["data"]["brightness"] == 255
# ---------------------------------------------------------------------------
# 3. send() -- REST notification
# ---------------------------------------------------------------------------
class TestSendNotification:
    """Adapter ``send()``: notification delivery over the fake server's REST API."""

    @pytest.mark.asyncio
    async def test_send_notification_delivered(self):
        """Adapter send() delivers notification to fake server REST endpoint."""
        async with FakeHAServer() as server:
            adapter = _adapter_for(server)
            outcome = await adapter.send("ha_events", "Test notification from agent")
            assert outcome.success is True
            assert len(server.received_notifications) == 1
            delivered = server.received_notifications[0]
            assert delivered["title"] == "Hermes Agent"
            assert delivered["message"] == "Test notification from agent"

    @pytest.mark.asyncio
    async def test_send_auth_failure(self):
        """send() returns failure when token is wrong."""
        async with FakeHAServer() as server:
            # Built by hand (not via _adapter_for) to supply a bad token.
            bad_config = PlatformConfig(
                enabled=True,
                token="wrong-token",
                extra={"url": server.url},
            )
            adapter = HomeAssistantAdapter(bad_config)
            outcome = await adapter.send("ha_events", "Should fail")
            assert outcome.success is False
            assert "401" in outcome.error
# ---------------------------------------------------------------------------
# 4. Auth and error cases
# ---------------------------------------------------------------------------
class TestAuthAndErrors:
    """HTTP error statuses from the server surface as aiohttp response errors."""

    @pytest.mark.asyncio
    async def test_rest_unauthorized(self, monkeypatch):
        """Async function raises on 401 when token is wrong."""
        import aiohttp as _aiohttp
        async with FakeHAServer() as server:
            # Correct URL, deliberately wrong token.
            overrides = {"_HASS_URL": server.url, "_HASS_TOKEN": "bad-token"}
            for attr, value in overrides.items():
                monkeypatch.setattr(f"tools.homeassistant_tool.{attr}", value)
            with pytest.raises(_aiohttp.ClientResponseError) as exc_info:
                await _async_list_entities()
            assert exc_info.value.status == 401

    @pytest.mark.asyncio
    async def test_rest_server_error(self, monkeypatch):
        """Async function raises on 500 response."""
        import aiohttp as _aiohttp
        async with FakeHAServer() as server:
            server.force_500 = True
            overrides = {"_HASS_URL": server.url, "_HASS_TOKEN": server.token}
            for attr, value in overrides.items():
                monkeypatch.setattr(f"tools.homeassistant_tool.{attr}", value)
            with pytest.raises(_aiohttp.ClientResponseError) as exc_info:
                await _async_list_entities()
            assert exc_info.value.status == 500

View file

@ -0,0 +1,301 @@
#!/usr/bin/env python3
"""
Test Modal Terminal Tool
This script tests that the Modal terminal backend is correctly configured
and can execute commands in Modal sandboxes.
Usage:
# Run with Modal backend
TERMINAL_ENV=modal python tests/test_modal_terminal.py
# Or run directly (will use whatever TERMINAL_ENV is set in .env)
python tests/test_modal_terminal.py
"""
import pytest
pytestmark = pytest.mark.integration
import os
import sys
import json
from pathlib import Path
# Try to load .env file if python-dotenv is available
try:
    from dotenv import load_dotenv
    load_dotenv()
except ImportError:
    # Manually load .env if dotenv not available
    # The file is looked up three directory levels above this test file.
    env_file = Path(__file__).parent.parent.parent / ".env"
    if env_file.exists():
        with open(env_file) as f:
            for line in f:
                line = line.strip()
                # Only KEY=VALUE lines are used; blank lines and lines
                # starting with '#' are skipped.
                if line and not line.startswith('#') and '=' in line:
                    # Split on the first '=' only, so values may contain '='.
                    key, value = line.split('=', 1)
                    # Remove quotes if present
                    value = value.strip().strip('"').strip("'")
                    # setdefault: variables already set in the environment win.
                    os.environ.setdefault(key.strip(), value)
# Add project root to path for imports
parent_dir = Path(__file__).parent.parent.parent
sys.path.insert(0, str(parent_dir))
# Import terminal_tool module directly using importlib to avoid tools/__init__.py
import importlib.util
terminal_tool_path = parent_dir / "tools" / "terminal_tool.py"
spec = importlib.util.spec_from_file_location("terminal_tool", terminal_tool_path)
terminal_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(terminal_module)
# Module-level aliases for the terminal_tool attributes bound for use in this file.
terminal_tool = terminal_module.terminal_tool
check_terminal_requirements = terminal_module.check_terminal_requirements
_get_env_config = terminal_module._get_env_config
cleanup_vm = terminal_module.cleanup_vm
get_active_environments_info = terminal_module.get_active_environments_info
def test_modal_requirements():
    """Report the Modal configuration and check the terminal requirements.

    Returns:
        False when TERMINAL_ENV is not 'modal'; otherwise the result of
        ``check_terminal_requirements()``.
    """
    rule = "=" * 60
    print("\n" + rule)
    print("TEST 1: Modal Requirements Check")
    print(rule)

    config = _get_env_config()
    env_type = config['env_type']
    print(f"Current TERMINAL_ENV: {env_type}")
    print(f"Modal image: {config['modal_image']}")

    # Credentials may come from the environment or from ~/.modal.toml.
    token_status = '✅ Set' if os.getenv("MODAL_TOKEN_ID") else '❌ Not set'
    modal_toml = Path.home() / ".modal.toml"
    toml_status = '✅ Exists' if modal_toml.exists() else '❌ Not found'
    print("\nModal authentication:")
    print(f" MODAL_TOKEN_ID env var: {token_status}")
    print(f" ~/.modal.toml file: {toml_status}")

    if env_type != 'modal':
        print(f"\n⚠️ TERMINAL_ENV is '{env_type}', not 'modal'")
        print(" Set TERMINAL_ENV=modal in .env or export it to test Modal backend")
        return False

    requirements_met = check_terminal_requirements()
    verdict = '✅ Passed' if requirements_met else '❌ Failed'
    print(f"\nRequirements check: {verdict}")
    return requirements_met
def test_simple_command():
    """Run a single ``echo`` in a Modal sandbox and check its result.

    Returns:
        True if the command exited with 0 and its output contains the echoed
        text, False otherwise.
    """
    print("\n" + "=" * 60)
    print("TEST 2: Simple Command Execution")
    print("=" * 60)
    test_task_id = "modal_test_simple"
    print("Executing: echo 'Hello from Modal!'")
    try:
        result = terminal_tool("echo 'Hello from Modal!'", task_id=test_task_id)
        result_json = json.loads(result)
        print(f"\nResult:")
        print(f" Output: {result_json.get('output', '')[:200]}")
        print(f" Exit code: {result_json.get('exit_code')}")
        print(f" Error: {result_json.get('error')}")
        success = result_json.get('exit_code') == 0 and 'Hello from Modal!' in result_json.get('output', '')
        print(f"\nTest: {'✅ Passed' if success else '❌ Failed'}")
        return success
    finally:
        # Cleanup in `finally` so the sandbox is released even when the
        # command or the JSON decoding raises.
        cleanup_vm(test_task_id)
def test_python_execution():
    """Run a ``python3 -c`` one-liner through ``terminal_tool`` and validate it.

    The environment created for the task is always released with
    ``cleanup_vm``, even if the tool call or the JSON decoding raises.

    Returns:
        bool: True when the exit code is 0 and ``'Python'`` appears in the
        reported output, False otherwise.
    """
    print("\n" + "=" * 60)
    print("TEST 3: Python Execution")
    print("=" * 60)
    test_task_id = "modal_test_python"
    python_cmd = 'python3 -c "import sys; print(f\'Python {sys.version}\')"'
    try:
        print(f"Executing: {python_cmd}")
        result = terminal_tool(python_cmd, task_id=test_task_id)
        result_json = json.loads(result)
        # A missing or null output is treated as empty text so that slicing
        # and the substring check below cannot raise.
        output = result_json.get('output') or ''
        print("\nResult:")
        print(f" Output: {output[:200]}")
        print(f" Exit code: {result_json.get('exit_code')}")
        print(f" Error: {result_json.get('error')}")
        success = result_json.get('exit_code') == 0 and 'Python' in output
        print(f"\nTest: {'✅ Passed' if success else '❌ Failed'}")
        return success
    finally:
        # Cleanup must run on every path, including exceptions.
        cleanup_vm(test_task_id)
def test_pip_install():
    """Install a small package with pip and import it in the same command.

    The command is run with a 120 second timeout. The environment created for
    the task is always released with ``cleanup_vm``, even if the tool call or
    the JSON decoding raises.

    Returns:
        bool: True when the exit code is 0 and ``'Modal works!'`` appears in
        the reported output, False otherwise.
    """
    print("\n" + "=" * 60)
    print("TEST 4: Pip Install Test")
    print("=" * 60)
    test_task_id = "modal_test_pip"
    # Single source of truth for the command so the log line cannot drift
    # from what is actually executed.
    command = (
        "pip install --break-system-packages cowsay && "
        "python3 -c \"import cowsay; cowsay.cow('Modal works!')\""
    )
    try:
        print(f"Executing: {command}")
        result = terminal_tool(command, task_id=test_task_id, timeout=120)
        result_json = json.loads(result)
        # A missing or null output is treated as empty text so that slicing
        # and the substring check below cannot raise.
        output = result_json.get('output') or ''
        print("\nResult:")
        print(f" Output (last 500 chars): ...{output[-500:] if len(output) > 500 else output}")
        print(f" Exit code: {result_json.get('exit_code')}")
        print(f" Error: {result_json.get('error')}")
        success = result_json.get('exit_code') == 0 and 'Modal works!' in output
        print(f"\nTest: {'✅ Passed' if success else '❌ Failed'}")
        return success
    finally:
        # Cleanup must run on every path, including exceptions.
        cleanup_vm(test_task_id)
def test_filesystem_persistence():
    """Check that a file written by one command is readable by the next.

    Both commands use the same ``task_id``. The environment created for the
    task is always released with ``cleanup_vm``, even if a tool call or the
    JSON decoding raises.

    Returns:
        bool: True when both commands exit with 0 and the second one prints
        the text written by the first, False otherwise.
    """
    print("\n" + "=" * 60)
    print("TEST 5: Filesystem Persistence")
    print("=" * 60)
    test_task_id = "modal_test_persist"
    try:
        # Step 1: write the marker file.
        print("Step 1: Creating test file...")
        result1 = terminal_tool("echo 'persistence test' > /tmp/modal_test.txt", task_id=test_task_id)
        result1_json = json.loads(result1)
        print(f" Exit code: {result1_json.get('exit_code')}")

        # Step 2: read it back under the same task id.
        print("Step 2: Reading test file...")
        result2 = terminal_tool("cat /tmp/modal_test.txt", task_id=test_task_id)
        result2_json = json.loads(result2)
        # A missing or null output is treated as empty text.
        output = result2_json.get('output') or ''
        print(f" Output: {output}")
        print(f" Exit code: {result2_json.get('exit_code')}")

        success = (
            result1_json.get('exit_code') == 0
            and result2_json.get('exit_code') == 0
            and 'persistence test' in output
        )
        print(f"\nTest: {'✅ Passed' if success else '❌ Failed'}")
        return success
    finally:
        # Cleanup must run on every path, including exceptions.
        cleanup_vm(test_task_id)
def test_environment_isolation():
    """Check that a file written under one task id is not visible to another.

    The write under ``task1`` must succeed for the check to be meaningful; if
    it does not, the test is reported as failed instead of passing vacuously.
    Both environments are always released with ``cleanup_vm``, even if a tool
    call or the JSON decoding raises.

    Returns:
        bool: True when the write in task1 succeeded and task2's output does
        not contain the data written by task1, False otherwise.
    """
    print("\n" + "=" * 60)
    print("TEST 6: Environment Isolation")
    print("=" * 60)
    task1 = "modal_test_iso_1"
    task2 = "modal_test_iso_2"
    try:
        # Step 1: write the marker file in task1 and confirm it worked.
        print("Step 1: Creating file in task1...")
        result1 = terminal_tool("echo 'task1 data' > /tmp/isolated.txt", task_id=task1)
        result1_json = json.loads(result1)
        created = result1_json.get('exit_code') == 0
        print(f" Exit code: {result1_json.get('exit_code')}")
        if not created:
            print(" ⚠️ Could not create the file in task1; isolation cannot be verified")

        # Step 2: try to read it from task2 (it should not exist there).
        print("Step 2: Trying to read file from task2 (should not exist)...")
        result2 = terminal_tool("cat /tmp/isolated.txt 2>&1 || echo 'FILE_NOT_FOUND'", task_id=task2)
        result2_json = json.loads(result2)
        output = result2_json.get('output') or ''

        # Seeing task1's data in task2 means the environments are shared,
        # whatever else the output contains.
        isolated = created and 'task1 data' not in output
        print(f" Task2 output: {output[:200]}")
        print(f"\nTest: {'✅ Passed (environments isolated)' if isolated else '❌ Failed (environments NOT isolated)'}")
        return isolated
    finally:
        # Release both environments; the nested finally makes sure task2 is
        # cleaned up even if cleaning task1 raises.
        try:
            cleanup_vm(task1)
        finally:
            cleanup_vm(task2)
def main():
    """Run the whole terminal test suite and print a summary.

    Asks for confirmation on stdin when the configured environment type is
    not ``'modal'``.

    Returns:
        True when every test passed, False when at least one failed, and
        None when the run was aborted by the user or the requirements check
        failed.
    """
    rule = "=" * 60
    print("🧪 Modal Terminal Tool Test Suite")
    print(rule)

    # Show the active configuration before doing anything else.
    config = _get_env_config()
    print(f"\nCurrent configuration:")
    print(f" TERMINAL_ENV: {config['env_type']}")
    print(f" TERMINAL_MODAL_IMAGE: {config['modal_image']}")
    print(f" TERMINAL_TIMEOUT: {config['timeout']}s")

    if config['env_type'] != 'modal':
        print(f"\n⚠️ WARNING: TERMINAL_ENV is set to '{config['env_type']}', not 'modal'")
        print(" To test Modal specifically, set TERMINAL_ENV=modal")
        answer = input("\n Continue testing with current backend? (y/n): ")
        if answer.lower() != 'y':
            print("Aborting.")
            return

    # The requirements check gates every other test.
    results = {'requirements': test_modal_requirements()}
    if not results['requirements']:
        print("\n❌ Requirements not met. Cannot continue with other tests.")
        return

    remaining = (
        ('simple_command', test_simple_command),
        ('python_execution', test_python_execution),
        ('pip_install', test_pip_install),
        ('filesystem_persistence', test_filesystem_persistence),
        ('environment_isolation', test_environment_isolation),
    )
    for key, run_test in remaining:
        results[key] = run_test()

    # Summary
    print("\n" + rule)
    print("TEST SUMMARY")
    print(rule)
    total = len(results)
    passed = sum(1 for outcome in results.values() if outcome)
    for test_name, outcome in results.items():
        status = "✅ PASSED" if outcome else "❌ FAILED"
        print(f" {test_name}: {status}")
    print(f"\nTotal: {passed}/{total} tests passed")

    # Report environments that are still alive after the run.
    env_info = get_active_environments_info()
    print(f"\nActive environments after tests: {env_info['count']}")
    if env_info['count'] > 0:
        print(f" Task IDs: {env_info['task_ids']}")
    return passed == total
if __name__ == "__main__":
    # Exit status 0 only when main() reports that every test passed.
    success = main()
    exit_code = 0 if success else 1
    sys.exit(exit_code)

View file

@ -0,0 +1,611 @@
"""Integration tests for Discord voice channel audio flow.
Uses real NaCl encryption and Opus codec (no mocks for crypto/codec).
Does NOT require a Discord connection; it tests the VoiceReceiver
packet processing pipeline end-to-end.
Requires: PyNaCl>=1.5.0, discord.py[voice] (opus codec)
"""
import struct
import time
import pytest
pytestmark = pytest.mark.integration
# Skip entire module if voice deps are missing
pytest.importorskip("nacl.secret", reason="PyNaCl required for voice integration tests")
discord = pytest.importorskip("discord", reason="discord.py required for voice integration tests")
import nacl.secret
# Best-effort load of the Opus shared library; OPUS_AVAILABLE records whether
# discord.opus reports it as loaded afterwards.
try:
    if not discord.opus.is_loaded():
        import ctypes.util
        import os

        opus_path = ctypes.util.find_library("opus")
        if not opus_path:
            # Fall back to fixed library locations when the system lookup
            # finds nothing (the first path is a Homebrew prefix).
            for p in ("/opt/homebrew/lib/libopus.dylib", "/usr/local/lib/libopus.dylib"):
                if os.path.isfile(p):
                    opus_path = p
                    break
        if opus_path:
            discord.opus.load_opus(opus_path)
    OPUS_AVAILABLE = discord.opus.is_loaded()
except Exception:
    # Any failure while probing or loading simply means Opus is unavailable.
    OPUS_AVAILABLE = False
from types import SimpleNamespace
from unittest.mock import MagicMock
from gateway.platforms.discord import VoiceReceiver
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_secret_key():
"""Generate a random 32-byte key."""
import os
return os.urandom(32)
def _build_encrypted_rtp_packet(secret_key, opus_payload, ssrc=100, seq=1, timestamp=960):
    """Assemble an RTP packet whose payload is really encrypted with NaCl.

    Layout of the returned bytes: 12-byte RTP header, then the AEAD
    ciphertext of ``opus_payload``, then a 4-byte nonce counter.
    The header is passed to ``nacl.secret.Aead.encrypt`` as associated data.

    Args:
        secret_key: 32-byte key handed to ``nacl.secret.Aead``.
        opus_payload: Bytes to encrypt.
        ssrc: Value packed into the header's SSRC field.
        seq: Sequence number; packed into the header and reused as the
            nonce counter.
        timestamp: Value packed into the header's timestamp field.
    """
    # First two header bytes are 0x80 (version 2, no extension, no CSRC)
    # and 0x78 (payload type).
    rtp_header = struct.pack(">BBHII", 0x80, 0x78, seq, timestamp, ssrc)
    aead = nacl.secret.Aead(secret_key)
    # The 24-byte nonce is the 4-byte big-endian counter padded with zeros.
    counter = struct.pack(">I", seq)
    nonce = counter.ljust(24, b'\x00')
    sealed = aead.encrypt(opus_payload, rtp_header, nonce)
    # `.ciphertext` excludes the nonce prefix; only the 4-byte counter is
    # appended to the packet.
    return b"".join((rtp_header, sealed.ciphertext, counter))
def _make_voice_receiver(secret_key, dave_session=None, bot_ssrc=9999,
                         allowed_user_ids=None, members=None):
    """Build and start a ``VoiceReceiver`` around a mocked voice client.

    The mock's ``_connection`` carries the real ``secret_key`` (as a list),
    the optional ``dave_session`` and the bot's SSRC. ``bot_ssrc`` is also
    used as the id of ``vc.user``. ``members`` becomes the channel member
    list (empty when omitted).

    Returns:
        The receiver, after ``start()`` has been called on it.
    """
    voice_client = MagicMock()

    connection = voice_client._connection
    connection.secret_key = list(secret_key)
    connection.dave_session = dave_session
    connection.ssrc = bot_ssrc
    connection.add_socket_listener = MagicMock()
    connection.remove_socket_listener = MagicMock()
    connection.hook = None

    voice_client.user = SimpleNamespace(id=bot_ssrc)
    voice_client.channel = MagicMock()
    voice_client.channel.members = members or []

    started = VoiceReceiver(voice_client, allowed_user_ids=allowed_user_ids)
    started.start()
    return started
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
class TestRealNaClDecrypt:
    """Packets sealed with real NaCl are fed to ``_on_packet`` and the buffers inspected."""

    def test_valid_encrypted_packet_buffered(self):
        """A packet sealed with the receiver's own key lands in its SSRC buffer."""
        secret = _make_secret_key()
        payload = b'\xf8\xff\xfe'
        rx = _make_voice_receiver(secret)
        rx._on_packet(_build_encrypted_rtp_packet(secret, payload, ssrc=100))
        assert 100 in rx._buffers
        assert len(rx._buffers[100]) > 0

    def test_wrong_key_packet_dropped(self):
        """A packet sealed with a different key leaves the buffer empty."""
        receiver_key = _make_secret_key()
        other_key = _make_secret_key()
        payload = b'\xf8\xff\xfe'
        rx = _make_voice_receiver(receiver_key)
        rx._on_packet(_build_encrypted_rtp_packet(other_key, payload, ssrc=100))
        assert len(rx._buffers.get(100, b"")) == 0

    def test_bot_ssrc_ignored(self):
        """A packet carrying the bot's own SSRC creates no buffer at all."""
        secret = _make_secret_key()
        rx = _make_voice_receiver(secret, bot_ssrc=9999)
        own_packet = _build_encrypted_rtp_packet(secret, b'\xf8\xff\xfe', ssrc=9999)
        rx._on_packet(own_packet)
        assert len(rx._buffers) == 0

    def test_multiple_packets_accumulate(self):
        """Five consecutive valid packets leave a non-empty buffer."""
        secret = _make_secret_key()
        rx = _make_voice_receiver(secret)
        for number in range(1, 6):
            rx._on_packet(
                _build_encrypted_rtp_packet(
                    secret, b'\xf8\xff\xfe', ssrc=100, seq=number, timestamp=960 * number
                )
            )
        assert 100 in rx._buffers
        buffered = len(rx._buffers[100])
        assert buffered > 0, "Multiple packets should accumulate in buffer"

    def test_different_ssrcs_separate_buffers(self):
        """Each distinct SSRC gets a buffer of its own."""
        secret = _make_secret_key()
        rx = _make_voice_receiver(secret)
        sources = [100, 200, 300]
        for source in sources:
            rx._on_packet(_build_encrypted_rtp_packet(secret, b'\xf8\xff\xfe', ssrc=source))
        assert len(rx._buffers) == 3
        for source in sources:
            assert source in rx._buffers
class TestRealNaClWithDAVE:
    """Real NaCl decryption combined with a mocked DAVE session."""

    def test_dave_unknown_ssrc_passthrough(self):
        """With no SSRC mapping, DAVE is skipped and the audio is still buffered."""
        secret = _make_secret_key()
        dave_session = MagicMock()  # session exists, but SSRC 100 is never mapped
        rx = _make_voice_receiver(secret, dave_session=dave_session)
        rx._on_packet(_build_encrypted_rtp_packet(secret, b'\xf8\xff\xfe', ssrc=100))
        # No mapping for the SSRC, so DAVE decrypt must not have been invoked
        dave_session.decrypt.assert_not_called()
        # The audio is buffered regardless
        assert 100 in rx._buffers
        assert len(rx._buffers[100]) > 0

    def test_dave_unencrypted_error_passthrough(self):
        """An 'Unencrypted' DAVE failure keeps the NaCl-decrypted audio."""
        secret = _make_secret_key()
        dave_session = MagicMock()
        dave_session.decrypt.side_effect = Exception(
            "DecryptionFailed(UnencryptedWhenPassthroughDisabled)"
        )
        rx = _make_voice_receiver(secret, dave_session=dave_session)
        rx.map_ssrc(100, 42)
        rx._on_packet(_build_encrypted_rtp_packet(secret, b'\xf8\xff\xfe', ssrc=100))
        # DAVE was attempted exactly once, failed, and the data passed through
        dave_session.decrypt.assert_called_once()
        assert 100 in rx._buffers
        assert len(rx._buffers[100]) > 0

    def test_dave_real_error_drops(self):
        """Any other DAVE failure causes the packet to be dropped."""
        secret = _make_secret_key()
        dave_session = MagicMock()
        dave_session.decrypt.side_effect = Exception("KeyRotationFailed")
        rx = _make_voice_receiver(secret, dave_session=dave_session)
        rx.map_ssrc(100, 42)
        rx._on_packet(_build_encrypted_rtp_packet(secret, b'\xf8\xff\xfe', ssrc=100))
        assert len(rx._buffers.get(100, b"")) == 0
class TestFullVoiceFlow:
    """Whole pipeline: encrypt, receive, buffer, detect silence, collect utterance."""

    @staticmethod
    def _feed(rx, secret, seqs, ssrc=100):
        """Send one encrypted packet per sequence number in *seqs* to *rx*."""
        for number in seqs:
            rx._on_packet(
                _build_encrypted_rtp_packet(
                    secret, b'\xf8\xff\xfe', ssrc=ssrc, seq=number, timestamp=960 * number
                )
            )

    def test_single_utterance_flow(self):
        """Buffered packets followed by silence yield one utterance from check_silence."""
        secret = _make_secret_key()
        rx = _make_voice_receiver(secret)
        rx.map_ssrc(100, 42)
        # Original author's sizing note: the buffer must exceed
        # MIN_SPEECH_DURATION (0.5s); at 48kHz stereo 16-bit each frame
        # decodes to ~3840 bytes, so 96000 bytes needs ~25 frames.
        self._feed(rx, secret, range(1, 30))
        # Pretend the last packet arrived three seconds ago
        rx._last_packet_time[100] = time.monotonic() - 3.0
        finished = rx.check_silence()
        assert len(finished) == 1
        speaker, pcm = finished[0]
        assert speaker == 42
        assert len(pcm) > 0

    def test_utterance_with_ssrc_automap(self):
        """Without map_ssrc, the sole allowed member is used as the speaker."""
        secret = _make_secret_key()
        roster = [
            SimpleNamespace(id=9999, name="Bot"),
            SimpleNamespace(id=42, name="Alice"),
        ]
        rx = _make_voice_receiver(
            secret, allowed_user_ids={"42"}, members=roster
        )
        # map_ssrc is deliberately not called (missing SPEAKING event)
        self._feed(rx, secret, range(1, 30))
        rx._last_packet_time[100] = time.monotonic() - 3.0
        finished = rx.check_silence()
        assert len(finished) == 1
        assert finished[0][0] == 42  # resolved to the only allowed user

    def test_pause_blocks_during_playback(self):
        """Packets are ignored while paused and accepted again after resume."""
        secret = _make_secret_key()
        rx = _make_voice_receiver(secret)
        # Paused state (used for echo prevention during TTS playback)
        rx.pause()
        frame = _build_encrypted_rtp_packet(secret, b'\xf8\xff\xfe', ssrc=100)
        rx._on_packet(frame)
        assert len(rx._buffers.get(100, b"")) == 0
        # Same packet again once resumed
        rx.resume()
        rx._on_packet(frame)
        assert 100 in rx._buffers
        assert len(rx._buffers[100]) > 0

    def test_corrupted_packet_ignored(self):
        """Truncated or malformed packets never create a buffer."""
        secret = _make_secret_key()
        rx = _make_voice_receiver(secret)
        malformed = (
            # Too short
            b"\x00" * 5,
            # Wrong RTP version
            struct.pack(">BBHII", 0x00, 0x78, 1, 960, 100) + b"\x00" * 20,
            # Wrong payload type
            struct.pack(">BBHII", 0x80, 0x00, 1, 960, 100) + b"\x00" * 20,
        )
        for raw in malformed:
            rx._on_packet(raw)
            assert len(rx._buffers) == 0

    def test_stop_cleans_everything(self):
        """stop() resets the running flag, buffers, SSRC map and decoders."""
        secret = _make_secret_key()
        rx = _make_voice_receiver(secret)
        rx.map_ssrc(100, 42)
        self._feed(rx, secret, range(1, 10))
        assert len(rx._buffers[100]) > 0
        rx.stop()
        assert rx._running is False
        assert len(rx._buffers) == 0
        assert len(rx._ssrc_to_user) == 0
        assert len(rx._decoders) == 0
class TestSPEAKINGHook:
    """Mapping of SSRC values to user ids through the speaking hook."""

    def test_speaking_hook_installed(self):
        """After start(), the connection's hook attribute is no longer None."""
        rx = _make_voice_receiver(_make_secret_key())
        connection = rx._vc._connection
        # The helper sets hook to None before start(); it must be replaced
        assert connection.hook is not None

    def test_map_ssrc_via_speaking(self):
        """map_ssrc records the user id under the given SSRC."""
        rx = _make_voice_receiver(_make_secret_key())
        rx.map_ssrc(500, 12345)
        assert rx._ssrc_to_user[500] == 12345

    def test_map_ssrc_overwrites(self):
        """A second mapping for the same SSRC replaces the first."""
        rx = _make_voice_receiver(_make_secret_key())
        for user in (111, 222):
            rx.map_ssrc(500, user)
        assert rx._ssrc_to_user[500] == 222

    def test_speaking_mapped_audio_processed(self):
        """Audio from a mapped SSRC is reported under the mapped user id."""
        secret = _make_secret_key()
        rx = _make_voice_receiver(secret)
        rx.map_ssrc(100, 42)
        for number in range(1, 30):
            rx._on_packet(
                _build_encrypted_rtp_packet(
                    secret, b'\xf8\xff\xfe', ssrc=100, seq=number, timestamp=960 * number
                )
            )
        rx._last_packet_time[100] = time.monotonic() - 3.0
        finished = rx.check_silence()
        assert len(finished) == 1
        assert finished[0][0] == 42
class TestAuthFiltering:
    """Utterances are returned only for users permitted by the allow-list."""

    @staticmethod
    def _roster():
        """Return a fresh member list containing the bot (9999) and Alice (42)."""
        return [
            SimpleNamespace(id=9999, name="Bot"),
            SimpleNamespace(id=42, name="Alice"),
        ]

    @staticmethod
    def _speak_then_go_silent(rx, secret):
        """Feed 29 packets on SSRC 100, backdate the last packet, return check_silence()."""
        for number in range(1, 30):
            rx._on_packet(
                _build_encrypted_rtp_packet(
                    secret, b'\xf8\xff\xfe', ssrc=100, seq=number, timestamp=960 * number
                )
            )
        rx._last_packet_time[100] = time.monotonic() - 3.0
        return rx.check_silence()

    def test_allowed_user_audio_processed(self):
        """An allowed, explicitly mapped user produces exactly one utterance."""
        secret = _make_secret_key()
        rx = _make_voice_receiver(
            secret, allowed_user_ids={"42"}, members=self._roster(),
        )
        rx.map_ssrc(100, 42)
        finished = self._speak_then_go_silent(rx, secret)
        assert len(finished) == 1
        assert finished[0][0] == 42

    def test_automap_rejects_unallowed_user(self):
        """With Alice absent from the allow-list, no utterance is returned."""
        secret = _make_secret_key()
        rx = _make_voice_receiver(
            secret,
            allowed_user_ids={"99"},  # Alice (42) is not on the list
            members=self._roster(),
        )
        # map_ssrc is deliberately not called, so the SSRC stays unknown
        finished = self._speak_then_go_silent(rx, secret)
        assert len(finished) == 0

    def test_empty_allowlist_allows_all(self):
        """With allowed_user_ids=None the sole non-bot member is accepted."""
        secret = _make_secret_key()
        rx = _make_voice_receiver(
            secret, allowed_user_ids=None, members=self._roster(),
        )
        finished = self._speak_then_go_silent(rx, secret)
        # Resolved to the only member that is not the bot
        assert len(finished) == 1
        assert finished[0][0] == 42
class TestRejoinFlow:
    """Stopping a receiver and creating a new one, as happens on leave/rejoin."""

    @staticmethod
    def _feed(rx, secret, ssrc, seqs):
        """Send one encrypted packet per sequence number in *seqs* to *rx*."""
        for number in seqs:
            rx._on_packet(
                _build_encrypted_rtp_packet(
                    secret, b'\xf8\xff\xfe', ssrc=ssrc, seq=number, timestamp=960 * number
                )
            )

    def test_stop_then_new_receiver_clean_state(self):
        """A receiver created after stop() starts with no buffers, mappings or decoders."""
        secret = _make_secret_key()
        first = _make_voice_receiver(secret)
        first.map_ssrc(100, 42)
        self._feed(first, secret, 100, range(1, 10))
        assert len(first._buffers[100]) > 0
        first.stop()
        # Second receiver stands in for the rejoin
        second = _make_voice_receiver(secret)
        assert len(second._buffers) == 0
        assert len(second._ssrc_to_user) == 0
        assert len(second._decoders) == 0

    def test_rejoin_new_ssrc_works(self):
        """The same user mapped to a different SSRC on the new receiver is recognised."""
        secret = _make_secret_key()
        first = _make_voice_receiver(secret)
        first.map_ssrc(100, 42)  # SSRC before leaving
        first.stop()
        second = _make_voice_receiver(secret)
        second.map_ssrc(200, 42)  # SSRC after rejoining
        self._feed(second, secret, 200, range(1, 30))
        second._last_packet_time[200] = time.monotonic() - 3.0
        finished = second.check_silence()
        assert len(finished) == 1
        assert finished[0][0] == 42

    def test_rejoin_without_speaking_event_automap(self):
        """After rejoin with no map_ssrc call, the sole allowed user is still resolved."""
        secret = _make_secret_key()
        roster = [
            SimpleNamespace(id=9999, name="Bot"),
            SimpleNamespace(id=42, name="Alice"),
        ]
        # Session one
        first = _make_voice_receiver(
            secret, allowed_user_ids={"42"}, members=roster,
        )
        first.stop()
        # Session two uses a different key (Discord may assign a new secret_key)
        fresh_secret = _make_secret_key()
        second = _make_voice_receiver(
            fresh_secret, allowed_user_ids={"42"}, members=roster,
        )
        # map_ssrc is deliberately not called (missing SPEAKING event)
        self._feed(second, fresh_secret, 300, range(1, 30))
        second._last_packet_time[300] = time.monotonic() - 3.0
        finished = second.check_silence()
        assert len(finished) == 1
        assert finished[0][0] == 42
class TestMultiGuildIsolation:
    """Two receivers (standing in for two guilds) keep separate state."""

    def test_separate_receivers_independent(self):
        """Packets delivered to one receiver leave the other one empty."""
        secret_a = _make_secret_key()
        secret_b = _make_secret_key()
        rx_a = _make_voice_receiver(secret_a, bot_ssrc=1111)
        rx_b = _make_voice_receiver(secret_b, bot_ssrc=2222)
        rx_a.map_ssrc(100, 42)
        rx_b.map_ssrc(200, 99)
        # Only the first receiver gets traffic
        for number in range(1, 10):
            rx_a._on_packet(
                _build_encrypted_rtp_packet(
                    secret_a, b'\xf8\xff\xfe', ssrc=100, seq=number, timestamp=960 * number
                )
            )
        # The second receiver must not have seen anything
        assert len(rx_b._buffers) == 0
        assert 100 in rx_a._buffers

    def test_stop_one_doesnt_affect_other(self):
        """Stopping the first receiver leaves the second running with its data."""
        secret_a = _make_secret_key()
        secret_b = _make_secret_key()
        rx_a = _make_voice_receiver(secret_a)
        rx_b = _make_voice_receiver(secret_b)
        rx_a.map_ssrc(100, 42)
        rx_b.map_ssrc(200, 99)
        for number in range(1, 10):
            rx_b._on_packet(
                _build_encrypted_rtp_packet(
                    secret_b, b'\xf8\xff\xfe', ssrc=200, seq=number, timestamp=960 * number
                )
            )
        rx_a.stop()
        # The second receiver keeps running and keeps its buffer
        assert rx_b._running is True
        assert len(rx_b._buffers[200]) > 0
class TestEchoPreventionFlow:
    """pause()/resume() on the receiver, as used around TTS playback."""

    @staticmethod
    def _feed(rx, secret, seqs):
        """Send one encrypted packet on SSRC 100 per sequence number in *seqs*."""
        for number in seqs:
            rx._on_packet(
                _build_encrypted_rtp_packet(
                    secret, b'\xf8\xff\xfe', ssrc=100, seq=number, timestamp=960 * number
                )
            )

    def test_audio_during_pause_ignored(self):
        """Nothing is buffered for packets that arrive while paused."""
        secret = _make_secret_key()
        rx = _make_voice_receiver(secret)
        rx.map_ssrc(100, 42)
        rx.pause()
        self._feed(rx, secret, range(1, 30))
        assert len(rx._buffers.get(100, b"")) == 0

    def test_audio_after_resume_processed(self):
        """Packets sent after resume() are buffered and yield an utterance."""
        secret = _make_secret_key()
        rx = _make_voice_receiver(secret)
        rx.map_ssrc(100, 42)
        # Phase 1: paused, so these four packets must be discarded
        rx.pause()
        self._feed(rx, secret, range(1, 5))
        assert len(rx._buffers.get(100, b"")) == 0
        # Phase 2: resumed, so these thirty packets must be kept
        rx.resume()
        self._feed(rx, secret, range(5, 35))
        assert len(rx._buffers[100]) > 0
        rx._last_packet_time[100] = time.monotonic() - 3.0
        finished = rx.check_silence()
        assert len(finished) == 1
        assert finished[0][0] == 42

View file

@ -0,0 +1,628 @@
#!/usr/bin/env python3
"""
Comprehensive Test Suite for Web Tools Module
This script tests all web tools functionality to ensure they work correctly.
Run this after any updates to the web_tools.py module or backend libraries.
Usage:
python test_web_tools.py # Run all tests
python test_web_tools.py --no-llm # Skip LLM processing tests
python test_web_tools.py --verbose # Show detailed output
Requirements:
- PARALLEL_API_KEY or FIRECRAWL_API_KEY environment variable must be set
- An auxiliary LLM provider (OPENROUTER_API_KEY or Nous Portal auth) (optional, for LLM tests)
"""
import pytest
pytestmark = pytest.mark.integration
import json
import asyncio
import sys
import os
import argparse
from datetime import datetime
from typing import List
# Import the web tools to test (updated path after moving tools/)
from tools.web_tools import (
web_search_tool,
web_extract_tool,
web_crawl_tool,
check_firecrawl_api_key,
check_web_api_key,
check_auxiliary_model,
get_debug_session_info,
_get_backend,
)
class Colors:
    """Named ANSI escape sequences used to style terminal output."""

    # Bright foreground colours
    HEADER = '\x1b[95m'
    BLUE = '\x1b[94m'
    CYAN = '\x1b[96m'
    GREEN = '\x1b[92m'
    WARNING = '\x1b[93m'
    FAIL = '\x1b[91m'
    # Reset and text attributes
    ENDC = '\x1b[0m'
    BOLD = '\x1b[1m'
    UNDERLINE = '\x1b[4m'
def print_header(text: str):
    """Print *text* in bold header colour between two 60-character '=' rules."""
    style = f"{Colors.HEADER}{Colors.BOLD}"
    rule = f"{style}{'='*60}{Colors.ENDC}"
    # The top rule is preceded by a blank line
    print(f"\n{rule}")
    print(f"{style}{text}{Colors.ENDC}")
    print(rule)
def print_section(text: str):
    """Print a bold cyan section title followed by a 50-character '-' rule."""
    title_line = f"\n{Colors.CYAN}{Colors.BOLD}📌 {text}{Colors.ENDC}"
    underline = f"{Colors.CYAN}{'-'*50}{Colors.ENDC}"
    print(title_line)
    print(underline)
def print_success(text: str):
    """Print *text* wrapped in the green colour code."""
    coloured = f"{Colors.GREEN}{text}{Colors.ENDC}"
    print(coloured)
def print_error(text: str):
    """Print *text* wrapped in the failure (red) colour code."""
    coloured = f"{Colors.FAIL}{text}{Colors.ENDC}"
    print(coloured)
def print_warning(text: str):
    """Print *text* with a warning sign, wrapped in the warning colour code."""
    coloured = f"{Colors.WARNING}⚠️ {text}{Colors.ENDC}"
    print(coloured)
def print_info(text: str, indent: int = 0):
    """Print *text* in blue, prefixed by ``indent`` repetitions of the indent unit."""
    padding = " " * indent
    coloured = f"{Colors.BLUE} {text}{Colors.ENDC}"
    print(f"{padding}{coloured}")
class WebToolsTester:
"""Test suite for web tools"""
def __init__(self, verbose: bool = False, test_llm: bool = True):
self.verbose = verbose
self.test_llm = test_llm
self.test_results = {
"passed": [],
"failed": [],
"skipped": []
}
self.start_time = None
self.end_time = None
def log_result(self, test_name: str, status: str, details: str = ""):
    """Record one test outcome and print it in the matching colour.

    Args:
        test_name: Label of the test.
        status: ``"passed"``, ``"failed"`` or ``"skipped"``; any other value
            is neither stored nor printed.
        details: Optional text appended to the printed line.
    """
    entry = {
        "test": test_name,
        "status": status,
        "details": details,
        "timestamp": datetime.now().isoformat(),
    }
    if status == "skipped":
        self.test_results["skipped"].append(entry)
        message = f"{test_name} skipped: {details}" if details else f"{test_name} skipped"
        print_warning(message)
        return
    if status not in ("passed", "failed"):
        return
    # Passed and failed share the same message format and differ only in
    # bucket and colour.
    message = f"{test_name}: {details}" if details else test_name
    if status == "passed":
        self.test_results["passed"].append(entry)
        print_success(message)
    else:
        self.test_results["failed"].append(entry)
        print_error(message)
def test_environment(self) -> bool:
    """Check API keys, the auxiliary LLM and debug mode.

    Disables LLM tests (``self.test_llm = False``) when no auxiliary model is
    available.

    Returns:
        False when no web backend API key is configured, True otherwise.
    """
    print_section("Environment Check")

    # A web backend key (Parallel or Firecrawl) is mandatory.
    if not check_web_api_key():
        self.log_result("Web Backend API Key", "failed", "PARALLEL_API_KEY or FIRECRAWL_API_KEY not set")
        return False
    backend = _get_backend()
    self.log_result("Web Backend API Key", "passed", f"Using {backend} backend")

    # The auxiliary LLM is optional; without it the LLM tests are skipped.
    if check_auxiliary_model():
        self.log_result("Auxiliary LLM", "passed", "Found")
    else:
        self.log_result("Auxiliary LLM", "skipped", "No auxiliary LLM provider available (LLM tests will be skipped)")
        self.test_llm = False

    # Surface debug session details when debug mode is on.
    debug_info = get_debug_session_info()
    if debug_info["enabled"]:
        print_info(f"Debug mode enabled - Session: {debug_info['session_id']}")
        print_info(f"Debug log: {debug_info['log_path']}")
    return True
def test_web_search(self) -> List[str]:
    """Run a few searches and validate the shape of every response.

    Each query is logged as passed only if the response is valid JSON, has no
    ``error`` key, contains ``success`` and ``data``, and every entry of
    ``data.web`` has ``url``, ``title`` and ``description``.

    Returns:
        Up to three URLs taken from valid results, for the extraction tests.
    """
    print_section("Test 1: Web Search")
    required_fields = ["url", "title", "description"]
    cases = [
        ("Python web scraping tutorial", 5),
        ("Firecrawl API documentation", 3),
        ("inflammatory arthritis symptoms treatment", 8),  # medical query
    ]
    collected = []

    for query, limit in cases:
        label = f"Search: {query[:30]}..."
        try:
            print(f"\n Testing search: '{query}' (limit={limit})")
            if self.verbose:
                print(f" Calling web_search_tool(query='{query}', limit={limit})")

            raw = web_search_tool(query, limit)

            # Decode the response
            try:
                data = json.loads(raw)
            except json.JSONDecodeError as e:
                self.log_result(label, "failed", f"Invalid JSON: {e}")
                if self.verbose:
                    print(f" Raw response (first 500 chars): {raw[:500]}...")
                continue

            if "error" in data:
                self.log_result(label, "failed", f"API error: {data['error']}")
                continue

            # Top-level structure
            if "success" not in data or "data" not in data:
                self.log_result(label, "failed", "Missing success or data fields")
                if self.verbose:
                    print(f" Response keys: {list(data.keys())}")
                continue

            web_results = data.get("data", {}).get("web", [])
            if not web_results:
                self.log_result(label, "failed", "Empty web results array")
                if self.verbose:
                    print(f" data.web content: {data.get('data', {}).get('web')}")
                continue

            # Per-entry validation
            valid_count = 0
            problems = []
            for position, entry in enumerate(web_results, start=1):
                missing = [f for f in required_fields if f not in entry]
                if missing:
                    problems.append(f"Result {position} missing: {missing}")
                    if self.verbose:
                        print(f" Result {position}: ✗ Missing fields: {missing}")
                    continue
                valid_count += 1
                # Keep at most three URLs for the extraction test
                if len(collected) < 3:
                    collected.append(entry["url"])
                if self.verbose:
                    print(f" Result {position}: ✓ {entry['title'][:50]}...")
                    print(f" URL: {entry['url'][:60]}...")

            # Verdict for this query
            if valid_count == len(web_results):
                self.log_result(label, "passed", f"All {valid_count} results valid")
            else:
                self.log_result(
                    label,
                    "failed",
                    f"Only {valid_count}/{len(web_results)} valid. Issues: {'; '.join(problems[:3])}"
                )
        except Exception as e:
            self.log_result(label, "failed", f"Exception: {type(e).__name__}: {str(e)}")
            if self.verbose:
                import traceback
                print(f" Traceback: {traceback.format_exc()}")

    if self.verbose and collected:
        print(f"\n URLs collected for extraction test: {len(collected)}")
        for url in collected:
            print(f" - {url}")
    return collected
async def test_web_extract(self, urls: List[str] = None):
    """Extract up to two pages as markdown without LLM processing.

    Args:
        urls: URLs to extract; two default URLs are used when empty or None.

    The result is logged as passed when at least one page has non-empty
    content and no ``error`` field.
    """
    name = "Extract (no LLM)"
    print_section("Test 2: Web Extract (without LLM)")

    # Pick the input URLs
    if urls:
        print(f" Using {len(urls)} URLs from search results")
    else:
        urls = [
            "https://docs.firecrawl.dev/introduction",
            "https://www.python.org/about/"
        ]
        print(f" Using default URLs for testing")

    try:
        targets = urls[:2]  # at most two URLs are extracted
        print(f"\n Extracting content from {len(targets)} URL(s)...")
        for url in targets:
            print(f" - {url}")
        if self.verbose:
            print(f" Calling web_extract_tool(urls={targets}, format='markdown', use_llm_processing=False)")

        raw = await web_extract_tool(
            targets,
            format="markdown",
            use_llm_processing=False
        )

        # Decode the response
        try:
            data = json.loads(raw)
        except json.JSONDecodeError as e:
            self.log_result(name, "failed", f"Invalid JSON: {e}")
            if self.verbose:
                print(f" Raw response (first 500 chars): {raw[:500]}...")
            return

        if "error" in data:
            self.log_result(name, "failed", f"API error: {data['error']}")
            return

        pages = data.get("results", [])
        if not pages:
            self.log_result(name, "failed", "No results in response")
            if self.verbose:
                print(f" Response keys: {list(data.keys())}")
            return

        # Classify every page as error, extracted or empty
        extracted = 0
        errored = 0
        total_chars = 0
        notes = []
        for number, page in enumerate(pages, start=1):
            title = page.get("title", "No title")
            content = page.get("content", "")
            error = page.get("error")
            if error:
                errored += 1
                notes.append(f"Page {number}: ERROR - {error}")
                if self.verbose:
                    print(f" Page {number}: ✗ Error - {error}")
            elif content:
                size = len(content)
                total_chars += size
                extracted += 1
                notes.append(f"Page {number}: {title[:40]}... ({size} chars)")
                if self.verbose:
                    print(f" Page {number}: ✓ {title[:50]}... - {size} characters")
                    print(f" First 100 chars: {content[:100]}...")
            else:
                notes.append(f"Page {number}: {title[:40]}... (EMPTY)")
                if self.verbose:
                    print(f" Page {number}: ⚠ {title[:50]}... - Empty content")

        # Verdict
        if extracted > 0:
            self.log_result(
                name,
                "passed",
                f"{extracted}/{len(pages)} pages extracted, {total_chars} total chars"
            )
        else:
            self.log_result(
                name,
                "failed",
                f"No valid content. {errored} errors, {len(pages) - errored} empty"
            )

        if self.verbose:
            print(f"\n Extraction details:")
            for note in notes:
                print(f" {note}")
    except Exception as e:
        self.log_result(name, "failed", f"Exception: {type(e).__name__}: {str(e)}")
        if self.verbose:
            import traceback
            print(f" Traceback: {traceback.format_exc()}")
async def test_web_extract_with_llm(self, urls: List[str] = None):
    """Test web extraction with LLM processing.

    Extracts a single page with ``use_llm_processing=True`` and records
    exactly one "Extract (with LLM)" result through ``self.log_result``:

    * "skipped" when ``self.test_llm`` is false,
    * "passed" when the first returned page has non-empty content,
    * "failed" otherwise (top-level error, no results, page-level error,
      empty content, invalid JSON, or any raised exception).

    Args:
        urls: Candidate URLs; only the first one is used. When the list is
            empty or ``None`` a default Firecrawl docs page is extracted.
    """
    test_name = "Extract (with LLM)"
    print_section("Test 3: Web Extract (with Gemini LLM)")
    if not self.test_llm:
        self.log_result(test_name, "skipped", "LLM testing disabled")
        return

    # Use a URL likely to have substantial content
    test_url = urls[0] if urls else "https://docs.firecrawl.dev/features/scrape"
    try:
        print(f"\n Extracting and processing: {test_url}")
        raw_response = await web_extract_tool(
            [test_url],
            format="markdown",
            use_llm_processing=True,
            min_length=1000,  # Lower threshold for testing
        )
        data = json.loads(raw_response)
        if "error" in data:
            self.log_result(test_name, "failed", data["error"])
            return

        results = data.get("results", [])
        if not results:
            self.log_result(test_name, "failed", "No results returned")
            return

        # Only one URL was requested, so only the first page is inspected.
        page = results[0]
        content = page.get("content", "")
        if content:
            # Non-empty content is the pass criterion; its length is
            # reported for information only.
            self.log_result(
                test_name,
                "passed",
                f"Content processed: {len(content)} chars"
            )
            if self.verbose:
                print(f"\n First 300 chars of processed content:")
                print(f" {content[:300]}...")
            return

        # No content: prefer the page's own error message over the generic
        # one so the failure reason is not hidden.
        page_error = page.get("error")
        if page_error:
            self.log_result(test_name, "failed", f"Page error: {page_error}")
        else:
            self.log_result(test_name, "failed", "No content field in result")
    except json.JSONDecodeError as e:
        self.log_result(test_name, "failed", f"Invalid JSON: {e}")
    except Exception as e:
        # Include the exception type: str(e) alone can be empty.
        self.log_result(test_name, "failed", f"Exception: {type(e).__name__}: {str(e)}")
async def test_web_crawl(self):
    """Crawl each built-in test site and record one result per site.

    A site passes when the number of returned pages that carry content is
    at least the minimum configured for it. Invalid JSON, a top-level
    "error" key, an empty "results" list, or any raised exception is
    recorded as a failure for that site, and the loop moves on to the
    next one.
    """
    print_section("Test 4: Web Crawl")
    # Each entry: (url, crawl instructions, minimum pages with content)
    crawl_targets = [
        ("https://docs.firecrawl.dev", None, 2),  # Test docs site
        ("https://firecrawl.dev", None, 3),  # Test main site
    ]
    for url, instructions, expected_min_pages in crawl_targets:
        label = f"Crawl: {url}"
        try:
            print(f"\n Testing crawl of: {url}")
            if not instructions:
                print(f" No instructions (general crawl)")
            else:
                print(f" Instructions: {instructions}")
            print(f" Expected minimum pages: {expected_min_pages}")

            # Show what's being called
            if self.verbose:
                print(f" Calling web_crawl_tool(url='{url}', instructions={instructions}, use_llm_processing=False)")
            result = await web_crawl_tool(
                url,
                instructions=instructions,
                use_llm_processing=False,  # Disable LLM for faster testing
            )

            # The tool is expected to answer with a JSON document
            try:
                data = json.loads(result)
            except json.JSONDecodeError as e:
                self.log_result(label, "failed", f"Invalid JSON response: {e}")
                if self.verbose:
                    print(f" Raw response (first 500 chars): {result[:500]}...")
                continue

            # A top-level error fails the site outright
            if "error" in data:
                self.log_result(label, "failed", f"API error: {data['error']}")
                continue

            results = data.get("results", [])
            if not results:
                self.log_result(label, "failed", "No pages in results array")
                if self.verbose:
                    print(f" Full response: {json.dumps(data, indent=2)[:1000]}...")
                continue

            # Classify every page as errored, empty, or valid
            valid_pages = 0
            empty_pages = 0
            total_content = 0
            page_details = []
            for number, page in enumerate(results, start=1):
                body = page.get("content", "")
                heading = page.get("title", "Untitled")
                failure = page.get("error")
                if failure:
                    page_details.append(f"Page {number}: ERROR - {failure}")
                    continue
                if not body:
                    empty_pages += 1
                    page_details.append(f"Page {number}: {heading[:40]}... (EMPTY)")
                    continue
                valid_pages += 1
                body_len = len(body)
                total_content += body_len
                page_details.append(f"Page {number}: {heading[:40]}... ({body_len} chars)")

            # Detailed breakdown, capped at the first 10 pages
            if self.verbose:
                print(f"\n Crawl Results:")
                print(f" Total pages returned: {len(results)}")
                print(f" Valid pages (with content): {valid_pages}")
                print(f" Empty pages: {empty_pages}")
                print(f" Total content size: {total_content} characters")
                print(f"\n Page Details:")
                for detail in page_details[:10]:
                    print(f" - {detail}")
                if len(page_details) > 10:
                    print(f" ... and {len(page_details) - 10} more pages")

            # Work out the verdict first, then record it once
            if valid_pages >= expected_min_pages:
                status = "passed"
                summary = f"{valid_pages}/{len(results)} valid pages, {total_content} chars total"
            else:
                status = "failed"
                summary = f"Only {valid_pages} valid pages (expected >= {expected_min_pages}), {empty_pages} empty, {len(results)} total"
            self.log_result(label, status, summary)
        except Exception as e:
            self.log_result(label, "failed", f"Exception: {type(e).__name__}: {str(e)}")
            if self.verbose:
                import traceback
                print(f" Traceback:")
                trace_lines = traceback.format_exc().split("\n")
                print(" " + "\n ".join(trace_lines))
async def run_all_tests(self):
    """Run the whole suite, print a summary and save the results.

    Order: environment check, web search, extraction, extraction with LLM,
    crawl. The run stops right after the environment check when that check
    fails; in that case no summary is printed and nothing is saved.

    Returns:
        bool: ``True`` when no test was recorded as failed; ``False`` when
        the environment check fails or at least one test failed.
    """
    self.start_time = datetime.now()
    print_header("WEB TOOLS TEST SUITE")
    print(f"Started at: {self.start_time.strftime('%Y-%m-%d %H:%M:%S')}")

    # Test environment
    if not self.test_environment():
        print_error("\nCannot proceed without required API keys!")
        return False

    # Test search and collect URLs
    urls = self.test_web_search()

    # Test extraction
    await self.test_web_extract(urls or None)

    # Test extraction with LLM. Always call it: the method itself records
    # a "skipped" result when self.test_llm is false. Guarding the call
    # here meant the skip never reached the summary or the saved results.
    await self.test_web_extract_with_llm(urls or None)

    # Test crawling
    await self.test_web_crawl()

    # Print summary
    self.end_time = datetime.now()
    duration = (self.end_time - self.start_time).total_seconds()
    print_header("TEST SUMMARY")
    print(f"Duration: {duration:.2f} seconds")
    print(f"\n{Colors.GREEN}Passed: {len(self.test_results['passed'])}{Colors.ENDC}")
    print(f"{Colors.FAIL}Failed: {len(self.test_results['failed'])}{Colors.ENDC}")
    print(f"{Colors.WARNING}Skipped: {len(self.test_results['skipped'])}{Colors.ENDC}")

    # List failed tests
    if self.test_results["failed"]:
        print(f"\n{Colors.FAIL}{Colors.BOLD}Failed Tests:{Colors.ENDC}")
        for test in self.test_results["failed"]:
            print(f" - {test['test']}: {test['details']}")

    # Save results to file
    self.save_results()
    return len(self.test_results["failed"]) == 0
def save_results(self):
    """Save test results to a timestamped JSON file.

    The file is named ``test_results_web_tools_<YYYYmmdd_HHMMSS>.json`` and
    is written relative to the current working directory. It holds the run
    timing, per-status counts, the full ``self.test_results`` mapping and a
    snapshot of the environment checks.

    A failure while writing is reported with ``print_warning`` instead of
    being raised, so a save problem does not change the suite's outcome.
    """
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    filename = f"test_results_web_tools_{timestamp}.json"
    results = {
        "test_suite": "Web Tools",
        "start_time": self.start_time.isoformat() if self.start_time else None,
        "end_time": self.end_time.isoformat() if self.end_time else None,
        "duration_seconds": (self.end_time - self.start_time).total_seconds() if self.start_time and self.end_time else None,
        "summary": {
            "passed": len(self.test_results["passed"]),
            "failed": len(self.test_results["failed"]),
            "skipped": len(self.test_results["skipped"])
        },
        "results": self.test_results,
        "environment": {
            # The backend is only looked up when a web API key is present.
            "web_backend": _get_backend() if check_web_api_key() else None,
            "firecrawl_api_key": check_firecrawl_api_key(),
            "parallel_api_key": bool(os.getenv("PARALLEL_API_KEY")),
            "auxiliary_model": check_auxiliary_model(),
            "debug_mode": get_debug_session_info()["enabled"]
        }
    }
    try:
        # ensure_ascii=False writes non-ASCII characters verbatim, so the
        # file must be opened as UTF-8 explicitly; the locale default
        # encoding may be unable to represent them.
        with open(filename, 'w', encoding="utf-8") as f:
            json.dump(results, f, indent=2, ensure_ascii=False)
        print_info(f"Test results saved to: {filename}")
    except Exception as e:
        print_warning(f"Failed to save results: {e}")
async def main():
    """Parse the command-line flags, run the test suite and exit.

    The process exits with status 0 when ``run_all_tests`` reports success
    and with status 1 otherwise.
    """
    cli = argparse.ArgumentParser(description="Test Web Tools Module")
    # All options are plain on/off switches: (option strings, help text)
    switches = (
        (("--no-llm",), "Skip LLM processing tests"),
        (("--verbose", "-v"), "Show detailed output"),
        (("--debug",), "Enable debug mode for web tools"),
    )
    for option_strings, help_text in switches:
        cli.add_argument(*option_strings, action="store_true", help=help_text)
    options = cli.parse_args()

    # Debug mode is requested through the WEB_TOOLS_DEBUG environment variable
    if options.debug:
        os.environ["WEB_TOOLS_DEBUG"] = "true"
        print_info("Debug mode enabled for web tools")

    # Build the tester from the parsed switches and run everything
    suite = WebToolsTester(verbose=options.verbose, test_llm=not options.no_llm)
    all_passed = await suite.run_all_tests()

    # Exit with appropriate code
    if all_passed:
        sys.exit(0)
    sys.exit(1)
# Script entry point: only run the suite when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    asyncio.run(main())