Normalize chat ids to a single UUID form so locks, repository keys, and mount paths cannot diverge through path-like aliases.
159 lines
4.9 KiB
Python
159 lines
4.9 KiB
Python
from dataclasses import dataclass
|
|
from datetime import timedelta
|
|
from uuid import UUID, uuid4
|
|
|
|
from domain.sandbox import SandboxSession
|
|
from usecase.interface import (
|
|
Clock,
|
|
Logger,
|
|
SandboxLifecycleLocker,
|
|
SandboxRuntime,
|
|
SandboxSessionRepository,
|
|
)
|
|
|
|
|
|
@dataclass(frozen=True, slots=True)
|
|
class CreateSandboxCommand:
|
|
chat_id: str
|
|
|
|
|
|
class CreateSandbox:
    """Return a live sandbox for a chat, reusing or (re)provisioning it.

    All lifecycle work for one chat runs under ``locker.lock(chat_id)`` keyed by
    the canonical chat id, and the same canonical id is used for the repository
    lookup and for the runtime, so aliases of one UUID cannot diverge.
    """

    def __init__(
        self,
        repository: SandboxSessionRepository,
        locker: SandboxLifecycleLocker,
        runtime: SandboxRuntime,
        clock: Clock,
        logger: Logger,
        ttl: timedelta,
    ) -> None:
        """Wire the use case to its ports.

        Args:
            repository: Stores sessions; queried by chat id, saved, deleted.
            locker: Provides the per-chat ``lock(chat_id)`` context manager.
            runtime: Creates sandbox sessions and stops their containers.
            clock: Source of ``now()``.
            logger: Structured logger taking an event name and ``attrs``.
            ttl: Lifetime of a new session (``expires_at = created_at + ttl``).
        """
        self._repository = repository
        self._locker = locker
        self._runtime = runtime
        self._clock = clock
        self._logger = logger
        self._ttl = ttl

    def execute(self, command: CreateSandboxCommand) -> SandboxSession:
        """Return the chat's active sandbox, creating one when needed.

        An existing session whose ``expires_at`` is still after the clock's
        current time is reused. An expired one is stopped and deleted, then a
        new session is created and saved.

        Raises:
            ValueError: if ``command.chat_id`` is not a UUID.
            Any error from the repository or runtime propagates. If saving the
            new session fails, its container is stopped (best effort) before
            the original error is re-raised.
        """
        chat_id = _canonical_chat_id(command.chat_id)

        with self._locker.lock(chat_id):
            session = self._repository.get_active_by_chat_id(chat_id)
            now = self._clock.now()

            if session is not None and session.expires_at > now:
                self._logger.info(
                    'sandbox_reused',
                    attrs={
                        'chat_id': chat_id,
                        'session_id': session.session_id,
                        'container_id': session.container_id,
                    },
                )
                return session

            if session is not None:
                # Expired: tear the old sandbox down before provisioning anew.
                self._logger.info(
                    'sandbox_replaced',
                    attrs={
                        'chat_id': chat_id,
                        'session_id': session.session_id,
                        'container_id': session.container_id,
                    },
                )
                self._runtime.stop(session.container_id)
                self._repository.delete(session.session_id)

            created_at = self._clock.now()
            expires_at = created_at + self._ttl
            new_session = self._runtime.create(
                session_id=_new_session_id(),
                chat_id=chat_id,
                created_at=created_at,
                expires_at=expires_at,
            )
            try:
                self._repository.save(new_session)
            except Exception:
                # Cleanup only finds sessions through the repository, so a
                # container without a saved record would never be stopped.
                self._discard_unsaved(new_session)
                raise

            self._logger.info(
                'sandbox_created',
                attrs={
                    'chat_id': chat_id,
                    'session_id': new_session.session_id,
                    'container_id': new_session.container_id,
                },
            )
            return new_session

    def _discard_unsaved(self, session: SandboxSession) -> None:
        """Best-effort stop of a container whose session record was never saved.

        A failure here is logged rather than raised, so the caller's original
        exception is the one that propagates.
        """
        try:
            self._runtime.stop(session.container_id)
        except Exception as exc:
            self._logger.error(
                'sandbox_discard_failed',
                attrs={
                    'chat_id': session.chat_id,
                    'session_id': session.session_id,
                    'container_id': session.container_id,
                    'error': type(exc).__name__,
                },
            )
|
|
|
|
|
|
class CleanupExpiredSandboxes:
    """Sweep sandboxes whose TTL has elapsed.

    Candidates come from ``repository.list_expired``. Each one is re-read under
    the per-chat lifecycle lock before teardown, so a session that has since
    been replaced or is no longer expired is left untouched.
    """

    def __init__(
        self,
        repository: SandboxSessionRepository,
        locker: SandboxLifecycleLocker,
        runtime: SandboxRuntime,
        clock: Clock,
        logger: Logger,
    ) -> None:
        """Store the ports used by the sweep."""
        self._clock = clock
        self._logger = logger
        self._repository = repository
        self._locker = locker
        self._runtime = runtime

    def execute(self) -> list[SandboxSession]:
        """Clean every expired session and return the ones actually removed.

        A failure on one session is logged as ``sandbox_clean_failed`` with the
        exception's class name, and the sweep moves on to the next candidate.
        """
        removed: list[SandboxSession] = []

        for candidate in self._repository.list_expired(self._clock.now()):
            try:
                outcome = self._cleanup_session(candidate)
            except Exception as exc:
                self._logger.error(
                    'sandbox_clean_failed',
                    attrs={
                        'chat_id': candidate.chat_id,
                        'session_id': candidate.session_id,
                        'container_id': candidate.container_id,
                        'error': type(exc).__name__,
                    },
                )
            else:
                if outcome is not None:
                    removed.append(outcome)
                    self._logger.info(
                        'sandbox_cleaned',
                        attrs={
                            'chat_id': outcome.chat_id,
                            'session_id': outcome.session_id,
                            'container_id': outcome.container_id,
                        },
                    )

        return removed

    def _cleanup_session(self, session: SandboxSession) -> SandboxSession | None:
        """Tear down ``session`` if it is still the chat's active, expired one.

        Returns the removed session, or ``None`` when no active session exists,
        a different session is now active, or it has not expired yet.
        """
        chat_id = session.chat_id
        with self._locker.lock(chat_id):
            latest = self._repository.get_active_by_chat_id(chat_id)
            now = self._clock.now()

            # Re-check under the lock: the candidate list was built unlocked.
            if (
                latest is None
                or latest.session_id != session.session_id
                or latest.expires_at > now
            ):
                return None

            self._runtime.stop(latest.container_id)
            self._repository.delete(latest.session_id)
            return latest
|
|
|
|
|
|
def _new_session_id() -> str:
|
|
return uuid4().hex
|
|
|
|
|
|
def _canonical_chat_id(chat_id: str) -> str:
|
|
return str(UUID(str(chat_id).strip()))
|