# master/test/test_sandbox_usecase.py
#
# 589 lines
# 18 KiB
# Python

import threading
from datetime import UTC, datetime, timedelta
from domain.sandbox import SandboxSession, SandboxStatus
from repository.sandbox_lock import ProcessLocalSandboxLifecycleLocker
from repository.sandbox_session import InMemorySandboxSessionRepository
from usecase.sandbox import CleanupExpiredSandboxes, CreateSandbox, CreateSandboxCommand
class FakeClock:
    """Clock test double pinned to a single instant."""

    def __init__(self, now: datetime) -> None:
        # The instant that every call to now() reports.
        self._frozen_at = now

    def now(self) -> datetime:
        """Return the instant given to the constructor, unchanged."""
        return self._frozen_at
class FakeLogger:
    """Logger test double that records calls instead of emitting them.

    Every call is stored in ``messages`` as a ``(level, message, attrs)``
    tuple, in the order the calls were made.
    """

    def __init__(self) -> None:
        self.messages: list[
            tuple[str, str, dict[str, str | int | float | bool] | None]
        ] = []

    def _record(self, level: str, message: str, attrs) -> None:
        # Single place where a log entry is captured.
        entry = (level, message, attrs)
        self.messages.append(entry)

    def debug(self, message: str, attrs=None) -> None:
        """Record a 'debug' entry."""
        self._record('debug', message, attrs)

    def info(self, message: str, attrs=None) -> None:
        """Record an 'info' entry."""
        self._record('info', message, attrs)

    def warning(self, message: str, attrs=None) -> None:
        """Record a 'warning' entry."""
        self._record('warning', message, attrs)

    def error(self, message: str, attrs=None) -> None:
        """Record an 'error' entry."""
        self._record('error', message, attrs)
class FakeLockContext:
    """Context manager that does nothing on entry or exit."""

    def __enter__(self) -> None:
        """Enter the context; nothing is acquired and None is bound."""

    def __exit__(self, exc_type, exc, traceback) -> None:
        """Leave the context; returns None, so exceptions propagate."""
class FakeLocker:
    """Locker test double that remembers which chat ids were locked."""

    def __init__(self) -> None:
        # Chat ids in the order lock() was called with them.
        self.chat_ids: list[str] = []

    def lock(self, chat_id: str) -> FakeLockContext:
        """Note ``chat_id`` and hand back a no-op context manager."""
        context = FakeLockContext()
        self.chat_ids.append(chat_id)
        return context
class TrackingLockContext:
def __init__(
self,
locker: 'TrackingLocker',
chat_id: str,
inner_context,
) -> None:
self._locker = locker
self._chat_id = chat_id
self._inner_context = inner_context
def __enter__(self) -> None:
with self._locker._state_lock:
self._locker.chat_ids.append(self._chat_id)
self._locker._attempts += 1
if self._locker._attempts == 2:
self._locker.second_attempted.set()
self._inner_context.__enter__()
def __exit__(self, exc_type, exc, traceback) -> bool | None:
return self._inner_context.__exit__(exc_type, exc, traceback)
class TrackingLocker:
    """Locker wrapper that counts lock attempts around a real locker.

    The bookkeeping itself is done by the ``TrackingLockContext`` objects
    this class hands out; the state they update lives here.
    """

    def __init__(self) -> None:
        # State updated by TrackingLockContext on entry.
        self.chat_ids: list[str] = []
        self.second_attempted = threading.Event()
        self._attempts = 0
        self._state_lock = threading.Lock()
        # The locker whose contexts are wrapped.
        self._locker = ProcessLocalSandboxLifecycleLocker()

    def lock(self, chat_id: str) -> TrackingLockContext:
        """Wrap the underlying locker's context for ``chat_id``."""
        inner_context = self._locker.lock(chat_id)
        return TrackingLockContext(self, chat_id, inner_context)
class BlockingCreateRuntime:
    """Runtime test double whose ``create`` blocks until the test releases it.

    ``create`` records its keyword arguments, sets ``create_started``, and
    then waits up to one second for ``allow_create`` before returning a
    RUNNING session. ``stop`` only records the container id it was given.
    """

    def __init__(self) -> None:
        self.create_calls: list[dict[str, object]] = []
        self.stop_calls: list[str] = []
        # Set as soon as create() has recorded its call.
        self.create_started = threading.Event()
        # create() does not return until this is set (or the wait times out).
        self.allow_create = threading.Event()

    def create(
        self,
        *,
        session_id: str,
        chat_id: str,
        created_at: datetime,
        expires_at: datetime,
    ) -> SandboxSession:
        """Record the call, block until released, and return a new session.

        Raises:
            AssertionError: if ``allow_create`` is not set within one second.
        """
        self.create_calls.append(
            {
                'session_id': session_id,
                'chat_id': chat_id,
                'created_at': created_at,
                'expires_at': expires_at,
            }
        )
        self.create_started.set()
        # The wait is kept out of an ``assert`` statement: asserts are removed
        # under ``python -O``, which would silently drop the blocking
        # behaviour, and a bare assert gives no message on timeout.
        released = self.allow_create.wait(timeout=1)
        if not released:
            raise AssertionError(
                'allow_create was not set within 1 second of create() starting'
            )
        return SandboxSession(
            session_id=session_id,
            chat_id=chat_id,
            container_id=f'container-{session_id}',
            status=SandboxStatus.RUNNING,
            created_at=created_at,
            expires_at=expires_at,
        )

    def stop(self, container_id: str) -> None:
        """Record that ``container_id`` was asked to stop."""
        self.stop_calls.append(container_id)
class StaleSnapshotRepository(InMemorySandboxSessionRepository):
    """In-memory repository whose expired listing is a fixed snapshot.

    ``list_expired`` ignores both the stored sessions and the supplied time
    and always reports the one session given at construction.
    """

    def __init__(self, snapshot: SandboxSession) -> None:
        super().__init__()
        self._snapshot = snapshot

    def list_expired(self, now: datetime) -> list[SandboxSession]:
        """Return a one-element list holding the fixed snapshot."""
        stale = self._snapshot
        return [stale]
class FakeRuntime:
    """Runtime test double that records calls and fabricates sessions."""

    def __init__(self) -> None:
        # Container ids passed to stop(), in call order.
        self.stop_calls: list[str] = []
        # Keyword arguments of each create() call, in call order.
        self.create_calls: list[dict[str, object]] = []

    def create(
        self,
        *,
        session_id: str,
        chat_id: str,
        created_at: datetime,
        expires_at: datetime,
    ) -> SandboxSession:
        """Record the call and return a RUNNING session built from it.

        The container id of the returned session is ``container-`` followed
        by the session id.
        """
        call_record: dict[str, object] = dict(
            session_id=session_id,
            chat_id=chat_id,
            created_at=created_at,
            expires_at=expires_at,
        )
        self.create_calls.append(call_record)
        container_id = f'container-{session_id}'
        return SandboxSession(
            session_id=session_id,
            chat_id=chat_id,
            container_id=container_id,
            status=SandboxStatus.RUNNING,
            created_at=created_at,
            expires_at=expires_at,
        )

    def stop(self, container_id: str) -> None:
        """Record that ``container_id`` was asked to stop."""
        self.stop_calls.append(container_id)
class FailingStopRuntime(FakeRuntime):
    """FakeRuntime variant whose ``stop`` fails for one chosen container."""

    def __init__(self, failing_container_id: str) -> None:
        super().__init__()
        self._failing_container_id = failing_container_id

    def stop(self, container_id: str) -> None:
        """Record the call, then raise for the designated container.

        Raises:
            RuntimeError: with message 'stop_failed' when ``container_id``
                equals the id given at construction.
        """
        # The base class records the call, so failed attempts are visible
        # in stop_calls as well.
        super().stop(container_id)
        if container_id != self._failing_container_id:
            return
        raise RuntimeError('stop_failed')
def test_create_sandbox_reuses_active_session_when_not_expired() -> None:
    """A stored, unexpired session is returned; nothing is created or stopped."""
    now = datetime(2026, 4, 2, 12, 0, tzinfo=UTC)
    active = SandboxSession(
        session_id='session-1',
        chat_id='chat-1',
        container_id='container-1',
        status=SandboxStatus.RUNNING,
        created_at=now - timedelta(minutes=1),
        expires_at=now + timedelta(minutes=4),
    )
    repository = InMemorySandboxSessionRepository()
    repository.save(active)
    locker, logger, runtime = FakeLocker(), FakeLogger(), FakeRuntime()
    create_sandbox = CreateSandbox(
        repository=repository,
        locker=locker,
        runtime=runtime,
        clock=FakeClock(now),
        logger=logger,
        ttl=timedelta(minutes=5),
    )

    command = CreateSandboxCommand(chat_id='chat-1')
    reused = create_sandbox.execute(command)

    # The stored session comes back and stays the active one for the chat.
    assert reused == active
    assert repository.get_active_by_chat_id('chat-1') == active
    # The runtime was not touched at all.
    assert runtime.create_calls == []
    assert runtime.stop_calls == []
    assert locker.chat_ids == ['chat-1']
    expected_log = [
        (
            'info',
            'sandbox_reused',
            {
                'chat_id': 'chat-1',
                'session_id': 'session-1',
                'container_id': 'container-1',
            },
        )
    ]
    assert logger.messages == expected_log
def test_create_sandbox_replaces_expired_session_and_creates_new_one(
    monkeypatch,
) -> None:
    """A session expiring exactly at 'now' is stopped and replaced."""
    now = datetime(2026, 4, 2, 12, 0, tzinfo=UTC)
    ttl = timedelta(minutes=5)
    stale = SandboxSession(
        session_id='session-old',
        chat_id='chat-1',
        container_id='container-old',
        status=SandboxStatus.RUNNING,
        created_at=now - timedelta(minutes=10),
        # Expiry equal to the clock time: the boundary counts as expired.
        expires_at=now,
    )
    repository = InMemorySandboxSessionRepository()
    repository.save(stale)
    locker, logger, runtime = FakeLocker(), FakeLogger(), FakeRuntime()
    create_sandbox = CreateSandbox(
        repository=repository,
        locker=locker,
        runtime=runtime,
        clock=FakeClock(now),
        logger=logger,
        ttl=ttl,
    )
    # Pin the generated session id so the expectations below are exact.
    monkeypatch.setattr('usecase.sandbox._new_session_id', lambda: 'session-new')

    created = create_sandbox.execute(CreateSandboxCommand(chat_id='chat-1'))

    assert runtime.stop_calls == ['container-old']
    expected_create_call = {
        'session_id': 'session-new',
        'chat_id': 'chat-1',
        'created_at': now,
        'expires_at': now + ttl,
    }
    assert runtime.create_calls == [expected_create_call]
    expected_session = SandboxSession(
        session_id='session-new',
        chat_id='chat-1',
        container_id='container-session-new',
        status=SandboxStatus.RUNNING,
        created_at=now,
        expires_at=now + ttl,
    )
    assert created == expected_session
    assert repository.get_active_by_chat_id('chat-1') == created
    assert locker.chat_ids == ['chat-1']
    replaced_entry = (
        'info',
        'sandbox_replaced',
        {
            'chat_id': 'chat-1',
            'session_id': 'session-old',
            'container_id': 'container-old',
        },
    )
    created_entry = (
        'info',
        'sandbox_created',
        {
            'chat_id': 'chat-1',
            'session_id': 'session-new',
            'container_id': 'container-session-new',
        },
    )
    assert logger.messages == [replaced_entry, created_entry]
def test_create_sandbox_creates_new_session_when_none_exists() -> None:
    """With an empty repository, one session is created and stored."""
    now = datetime(2026, 4, 2, 12, 0, tzinfo=UTC)
    ttl = timedelta(minutes=5)
    expected_expiry = now + ttl
    repository = InMemorySandboxSessionRepository()
    locker, logger, runtime = FakeLocker(), FakeLogger(), FakeRuntime()
    create_sandbox = CreateSandbox(
        repository=repository,
        locker=locker,
        runtime=runtime,
        clock=FakeClock(now),
        logger=logger,
        ttl=ttl,
    )

    created = create_sandbox.execute(CreateSandboxCommand(chat_id='chat-1'))

    # The session id is generated, so expectations are derived from it.
    assert created.chat_id == 'chat-1'
    assert created.container_id == f'container-{created.session_id}'
    assert created.status is SandboxStatus.RUNNING
    assert created.created_at == now
    assert created.expires_at == expected_expiry
    assert len(runtime.create_calls) == 1
    (only_call,) = runtime.create_calls
    assert only_call == {
        'session_id': created.session_id,
        'chat_id': 'chat-1',
        'created_at': now,
        'expires_at': expected_expiry,
    }
    assert runtime.stop_calls == []
    assert repository.get_active_by_chat_id('chat-1') == created
    assert locker.chat_ids == ['chat-1']
    expected_log = [
        (
            'info',
            'sandbox_created',
            {
                'chat_id': 'chat-1',
                'session_id': created.session_id,
                'container_id': created.container_id,
            },
        )
    ]
    assert logger.messages == expected_log
def test_create_sandbox_serializes_duplicate_concurrent_create_for_chat_id(
    monkeypatch,
) -> None:
    """Two concurrent creates for one chat yield a single shared session.

    The first worker is held inside the runtime's ``create`` while the second
    worker attempts to take the lock. After the first is released, both must
    finish with the same session, and the runtime must have been asked to
    create only once.
    """
    now = datetime(2026, 4, 2, 12, 0, tzinfo=UTC)
    repository = InMemorySandboxSessionRepository()
    runtime = BlockingCreateRuntime()
    logger = FakeLogger()
    locker = TrackingLocker()
    usecase = CreateSandbox(
        repository=repository,
        locker=locker,
        runtime=runtime,
        clock=FakeClock(now),
        logger=logger,
        ttl=timedelta(minutes=5),
    )
    # Pin the generated session id so the expectations below are exact.
    monkeypatch.setattr('usecase.sandbox._new_session_id', lambda: 'session-new')
    results: list[SandboxSession | None] = [None, None]
    errors: list[Exception] = []

    def run_create(index: int) -> None:
        # Exceptions raised in a worker thread would otherwise be lost, so
        # they are collected and asserted on from the main thread.
        try:
            results[index] = usecase.execute(CreateSandboxCommand(chat_id='chat-1'))
        except Exception as exc:
            errors.append(exc)

    # Daemon threads: a worker stuck in the code under test must not keep the
    # interpreter alive after this test has already failed.
    first_thread = threading.Thread(target=run_create, args=(0,), daemon=True)
    second_thread = threading.Thread(target=run_create, args=(1,), daemon=True)
    first_thread.start()
    assert runtime.create_started.wait(timeout=1)
    second_thread.start()
    assert locker.second_attempted.wait(timeout=1)
    assert len(runtime.create_calls) == 1
    runtime.allow_create.set()
    first_thread.join(timeout=1)
    second_thread.join(timeout=1)
    # join() returns None whether or not it timed out, so liveness has to be
    # checked explicitly; otherwise a hung worker would only show up as a
    # confusing comparison against None further down.
    assert not first_thread.is_alive()
    assert not second_thread.is_alive()
    assert errors == []
    assert results[0] == results[1]
    assert results[0] == SandboxSession(
        session_id='session-new',
        chat_id='chat-1',
        container_id='container-session-new',
        status=SandboxStatus.RUNNING,
        created_at=now,
        expires_at=now + timedelta(minutes=5),
    )
    assert len(runtime.create_calls) == 1
    assert runtime.stop_calls == []
    assert repository.get_active_by_chat_id('chat-1') == results[0]
    assert locker.chat_ids == ['chat-1', 'chat-1']
    assert logger.messages == [
        (
            'info',
            'sandbox_created',
            {
                'chat_id': 'chat-1',
                'session_id': 'session-new',
                'container_id': 'container-session-new',
            },
        ),
        (
            'info',
            'sandbox_reused',
            {
                'chat_id': 'chat-1',
                'session_id': 'session-new',
                'container_id': 'container-session-new',
            },
        ),
    ]
def test_cleanup_expired_sandboxes_stops_and_deletes_only_expired_sessions() -> None:
    """Past-due and boundary sessions are cleaned; the active one survives."""
    now = datetime(2026, 4, 2, 12, 0, tzinfo=UTC)
    past_due = SandboxSession(
        session_id='session-expired',
        chat_id='chat-expired',
        container_id='container-expired',
        status=SandboxStatus.RUNNING,
        created_at=now - timedelta(minutes=10),
        expires_at=now - timedelta(seconds=1),
    )
    # Expires exactly at the clock time; expected to be cleaned as well.
    at_boundary = SandboxSession(
        session_id='session-boundary',
        chat_id='chat-boundary',
        container_id='container-boundary',
        status=SandboxStatus.RUNNING,
        created_at=now - timedelta(minutes=5),
        expires_at=now,
    )
    still_active = SandboxSession(
        session_id='session-active',
        chat_id='chat-active',
        container_id='container-active',
        status=SandboxStatus.RUNNING,
        created_at=now - timedelta(minutes=1),
        expires_at=now + timedelta(minutes=5),
    )
    repository = InMemorySandboxSessionRepository()
    for stored in (past_due, at_boundary, still_active):
        repository.save(stored)
    locker, logger, runtime = FakeLocker(), FakeLogger(), FakeRuntime()
    cleanup = CleanupExpiredSandboxes(
        repository=repository,
        locker=locker,
        runtime=runtime,
        clock=FakeClock(now),
        logger=logger,
    )

    cleaned = cleanup.execute()

    assert cleaned == [past_due, at_boundary]
    assert runtime.stop_calls == ['container-expired', 'container-boundary']
    assert repository.get_active_by_chat_id('chat-expired') is None
    assert repository.get_active_by_chat_id('chat-boundary') is None
    assert repository.get_active_by_chat_id('chat-active') == still_active
    assert locker.chat_ids == ['chat-expired', 'chat-boundary']
    expired_entry = (
        'info',
        'sandbox_cleaned',
        {
            'chat_id': 'chat-expired',
            'session_id': 'session-expired',
            'container_id': 'container-expired',
        },
    )
    boundary_entry = (
        'info',
        'sandbox_cleaned',
        {
            'chat_id': 'chat-boundary',
            'session_id': 'session-boundary',
            'container_id': 'container-boundary',
        },
    )
    assert logger.messages == [expired_entry, boundary_entry]
def test_cleanup_expired_sandboxes_skips_replaced_session_from_stale_snapshot() -> None:
    """A listed-as-expired session that was already replaced is left alone.

    The repository reports an old expired session for the chat, while the
    session actually stored for that chat is a newer, unexpired one.
    """
    now = datetime(2026, 4, 2, 12, 0, tzinfo=UTC)
    stale_snapshot = SandboxSession(
        session_id='session-expired',
        chat_id='chat-1',
        container_id='container-expired',
        status=SandboxStatus.RUNNING,
        created_at=now - timedelta(minutes=10),
        expires_at=now - timedelta(seconds=1),
    )
    replacement = SandboxSession(
        session_id='session-new',
        chat_id='chat-1',
        container_id='container-new',
        status=SandboxStatus.RUNNING,
        created_at=now - timedelta(seconds=30),
        expires_at=now + timedelta(minutes=5),
    )
    repository = StaleSnapshotRepository(stale_snapshot)
    repository.save(replacement)
    locker, logger, runtime = FakeLocker(), FakeLogger(), FakeRuntime()
    cleanup = CleanupExpiredSandboxes(
        repository=repository,
        locker=locker,
        runtime=runtime,
        clock=FakeClock(now),
        logger=logger,
    )

    cleaned = cleanup.execute()

    # Nothing was cleaned, stopped, or logged ...
    assert cleaned == []
    assert runtime.stop_calls == []
    assert logger.messages == []
    # ... the replacement is still active, and the chat was still locked once.
    assert repository.get_active_by_chat_id('chat-1') == replacement
    assert locker.chat_ids == ['chat-1']
def test_cleanup_expired_sandboxes_continues_after_stop_failure() -> None:
    """A failed stop is logged and kept; later sessions are still cleaned."""
    now = datetime(2026, 4, 2, 12, 0, tzinfo=UTC)
    unstoppable = SandboxSession(
        session_id='session-fail',
        chat_id='chat-fail',
        container_id='container-fail',
        status=SandboxStatus.RUNNING,
        created_at=now - timedelta(minutes=10),
        expires_at=now - timedelta(minutes=1),
    )
    stoppable = SandboxSession(
        session_id='session-clean',
        chat_id='chat-clean',
        container_id='container-clean',
        status=SandboxStatus.RUNNING,
        created_at=now - timedelta(minutes=9),
        expires_at=now - timedelta(seconds=1),
    )
    repository = InMemorySandboxSessionRepository()
    for stored in (unstoppable, stoppable):
        repository.save(stored)
    # stop() raises for 'container-fail' only.
    runtime = FailingStopRuntime('container-fail')
    locker, logger = FakeLocker(), FakeLogger()
    cleanup = CleanupExpiredSandboxes(
        repository=repository,
        locker=locker,
        runtime=runtime,
        clock=FakeClock(now),
        logger=logger,
    )

    cleaned = cleanup.execute()

    assert cleaned == [stoppable]
    # Both stops were attempted, the failing one first.
    assert runtime.stop_calls == ['container-fail', 'container-clean']
    # The session whose stop failed is still stored; the other is gone.
    assert repository.get_active_by_chat_id('chat-fail') == unstoppable
    assert repository.get_active_by_chat_id('chat-clean') is None
    assert locker.chat_ids == ['chat-fail', 'chat-clean']
    failure_entry = (
        'error',
        'sandbox_clean_failed',
        {
            'chat_id': 'chat-fail',
            'session_id': 'session-fail',
            'container_id': 'container-fail',
            'error': 'RuntimeError',
        },
    )
    success_entry = (
        'info',
        'sandbox_cleaned',
        {
            'chat_id': 'chat-clean',
            'session_id': 'session-clean',
            'container_id': 'container-clean',
        },
    )
    assert logger.messages == [failure_entry, success_entry]