diff --git a/adapter/di/container.py b/adapter/di/container.py index 592cf6e..55f95a0 100644 --- a/adapter/di/container.py +++ b/adapter/di/container.py @@ -12,6 +12,7 @@ from adapter.docker.runtime import DockerSandboxRuntime from adapter.observability.factory import build_observability from adapter.observability.runtime import ObservabilityRuntime from domain.user import User +from repository.sandbox_lock import ProcessLocalSandboxLifecycleLocker from repository.sandbox_session import InMemorySandboxSessionRepository from repository.user import InMemoryUserRepository from usecase.interface import Clock @@ -85,6 +86,7 @@ def build_container( observability.tracer, [User(id='123', email='aza@gglamer.ru', name='gglamer')] ) sandbox_repository = InMemorySandboxSessionRepository() + sandbox_locker = ProcessLocalSandboxLifecycleLocker() sandbox_runtime = DockerSandboxRuntime(app_config.sandbox, docker_client) repositories = AppRepositories( @@ -99,6 +101,7 @@ def build_container( ), create_sandbox=CreateSandbox( repository=sandbox_repository, + locker=sandbox_locker, runtime=sandbox_runtime, clock=clock, logger=observability.logger, @@ -106,6 +109,7 @@ def build_container( ), cleanup_expired_sandboxes=CleanupExpiredSandboxes( repository=sandbox_repository, + locker=sandbox_locker, runtime=sandbox_runtime, clock=clock, logger=observability.logger, diff --git a/repository/sandbox_lock.py b/repository/sandbox_lock.py new file mode 100644 index 0000000..704aeae --- /dev/null +++ b/repository/sandbox_lock.py @@ -0,0 +1,43 @@ +import threading +from types import TracebackType +from typing import Protocol + +from usecase.interface import LockContext, SandboxLifecycleLocker + + +class _SyncLock(Protocol): + def acquire(self, blocking: bool = True, timeout: float = -1) -> bool: ... + + def release(self) -> None: ... + + +class _ChatLock(LockContext): + def __init__(self, lock: _SyncLock) -> None: + self._lock = lock + + def __enter__(self) -> None: + self._lock.acquire() + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc: BaseException | None, + traceback: TracebackType | None, + ) -> bool | None: + self._lock.release() + return None + + +class ProcessLocalSandboxLifecycleLocker(SandboxLifecycleLocker): + def __init__(self) -> None: + self._registry_lock = threading.Lock() + self._locks_by_chat_id: dict[str, _SyncLock] = {} + + def lock(self, chat_id: str) -> LockContext: + with self._registry_lock: + lock = self._locks_by_chat_id.get(chat_id) + if lock is None: + lock = threading.Lock() + self._locks_by_chat_id[chat_id] = lock + + return _ChatLock(lock) diff --git a/repository/sandbox_session.py b/repository/sandbox_session.py index 9b23cd7..6707d0c 100644 --- a/repository/sandbox_session.py +++ b/repository/sandbox_session.py @@ -1,3 +1,4 @@ +import threading from datetime import datetime from domain.sandbox import SandboxSession @@ -7,22 +8,27 @@ from usecase.interface import SandboxSessionRepository class InMemorySandboxSessionRepository(SandboxSessionRepository): def __init__(self) -> None: self._sessions_by_chat_id: dict[str, SandboxSession] = {} + self._lock = threading.Lock() def get_active_by_chat_id(self, chat_id: str) -> SandboxSession | None: - return self._sessions_by_chat_id.get(chat_id) + with self._lock: + return self._sessions_by_chat_id.get(chat_id) def list_expired(self, now: datetime) -> list[SandboxSession]: - return [ - session - for session in self._sessions_by_chat_id.values() - if session.expires_at <= now - ] + with self._lock: + return [ + session + for session in self._sessions_by_chat_id.values() + if session.expires_at <= now + ] def save(self, session: SandboxSession) -> None: - self._sessions_by_chat_id[session.chat_id] = session + with self._lock: + self._sessions_by_chat_id[session.chat_id] = session def delete(self, session_id: str) -> None: - for chat_id, session in tuple(self._sessions_by_chat_id.items()): - if session.session_id == session_id: - del self._sessions_by_chat_id[chat_id] - return + with self._lock: + for chat_id, session in tuple(self._sessions_by_chat_id.items()): + if session.session_id == session_id: + del self._sessions_by_chat_id[chat_id] + return diff --git a/tasks.md b/tasks.md index f624b35..75b4ce8 100644 --- a/tasks.md +++ b/tasks.md @@ -120,7 +120,7 @@ ### M09. Сериализация lifecycle sandbox по `chat_id` - Субагент: `feature-developer` -- Статус: pending +- Статус: completed - Зависимости: `M08` - Commit required: no - Scope: убрать гонки между параллельными `create` и cleanup для одного `chat_id` diff --git a/test/test_sandbox_usecase.py b/test/test_sandbox_usecase.py index b050b69..a58dd5c 100644 --- a/test/test_sandbox_usecase.py +++ b/test/test_sandbox_usecase.py @@ -32,6 +32,23 @@ class FakeLogger: self.messages.append(('error', message, attrs)) +class FakeLockContext: + def __enter__(self) -> None: + return None + + def __exit__(self, exc_type, exc, traceback) -> None: + return None + + +class FakeLocker: + def __init__(self) -> None: + self.chat_ids: list[str] = [] + + def lock(self, chat_id: str) -> FakeLockContext: + self.chat_ids.append(chat_id) + return FakeLockContext() + + class FakeRuntime: def __init__(self) -> None: self.create_calls: list[dict[str, object]] = [] @@ -80,8 +97,10 @@ def test_create_sandbox_reuses_active_session_when_not_expired() -> None: repository.save(session) runtime = FakeRuntime() logger = FakeLogger() + locker = FakeLocker() usecase = CreateSandbox( repository=repository, + locker=locker, runtime=runtime, clock=FakeClock(now), logger=logger, @@ -94,6 +113,7 @@ def test_create_sandbox_reuses_active_session_when_not_expired() -> None: assert runtime.create_calls == [] assert runtime.stop_calls == [] assert repository.get_active_by_chat_id('chat-1') == session + assert locker.chat_ids == ['chat-1'] assert logger.messages == [ ( 'info', @@ -123,8 +143,10 @@ def test_create_sandbox_replaces_expired_session_and_creates_new_one( repository.save(expired_session) runtime = FakeRuntime() logger = FakeLogger() + locker = FakeLocker() usecase = CreateSandbox( repository=repository, + locker=locker, runtime=runtime, clock=FakeClock(now), logger=logger, @@ -152,6 +174,7 @@ def test_create_sandbox_replaces_expired_session_and_creates_new_one( expires_at=now + timedelta(minutes=5), ) assert repository.get_active_by_chat_id('chat-1') == result + assert locker.chat_ids == ['chat-1'] assert logger.messages == [ ( 'info', @@ -179,8 +202,10 @@ def test_create_sandbox_creates_new_session_when_none_exists() -> None: repository = InMemorySandboxSessionRepository() runtime = FakeRuntime() logger = FakeLogger() + locker = FakeLocker() usecase = CreateSandbox( repository=repository, + locker=locker, runtime=runtime, clock=FakeClock(now), logger=logger, @@ -203,6 +228,7 @@ def test_create_sandbox_creates_new_session_when_none_exists() -> None: } assert runtime.stop_calls == [] assert repository.get_active_by_chat_id('chat-1') == result + assert locker.chat_ids == ['chat-1'] assert logger.messages == [ ( 'info', @@ -248,8 +274,10 @@ def test_cleanup_expired_sandboxes_stops_and_deletes_only_expired_sessions() -> repository.save(active_session) runtime = FakeRuntime() logger = FakeLogger() + locker = FakeLocker() usecase = CleanupExpiredSandboxes( repository=repository, + locker=locker, runtime=runtime, clock=FakeClock(now), logger=logger, @@ -262,6 +290,7 @@ def test_cleanup_expired_sandboxes_stops_and_deletes_only_expired_sessions() -> assert repository.get_active_by_chat_id('chat-expired') is None assert repository.get_active_by_chat_id('chat-boundary') is None assert repository.get_active_by_chat_id('chat-active') == active_session + assert locker.chat_ids == ['chat-expired', 'chat-boundary'] assert logger.messages == [ ( 'info', diff --git a/usecase/interface.py b/usecase/interface.py index 0c8bcaa..0c0e321 100644 --- a/usecase/interface.py +++ b/usecase/interface.py @@ -28,6 +28,21 @@ class SandboxSessionRepository(Protocol): def delete(self, session_id: str) -> None: ... +class LockContext(Protocol): + def __enter__(self) -> None: ... + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc: BaseException | None, + traceback: TracebackType | None, + ) -> bool | None: ... + + +class SandboxLifecycleLocker(Protocol): + def lock(self, chat_id: str) -> LockContext: ... + + class SandboxRuntime(Protocol): def create( self, diff --git a/usecase/sandbox.py b/usecase/sandbox.py index 65740ef..0cb39e8 100644 --- a/usecase/sandbox.py +++ b/usecase/sandbox.py @@ -3,7 +3,13 @@ from datetime import timedelta from uuid import uuid4 from domain.sandbox import SandboxSession -from usecase.interface import Clock, Logger, SandboxRuntime, SandboxSessionRepository +from usecase.interface import ( + Clock, + Logger, + SandboxLifecycleLocker, + SandboxRuntime, + SandboxSessionRepository, +) @dataclass(frozen=True, slots=True) @@ -15,93 +21,112 @@ class CreateSandbox: def __init__( self, repository: SandboxSessionRepository, + locker: SandboxLifecycleLocker, runtime: SandboxRuntime, clock: Clock, logger: Logger, ttl: timedelta, ) -> None: self._repository = repository + self._locker = locker self._runtime = runtime self._clock = clock self._logger = logger self._ttl = ttl def execute(self, command: CreateSandboxCommand) -> SandboxSession: - now = self._clock.now() - session = self._repository.get_active_by_chat_id(command.chat_id) + with self._locker.lock(command.chat_id): + session = self._repository.get_active_by_chat_id(command.chat_id) + now = self._clock.now() - if session is not None and session.expires_at > now: + if session is not None and session.expires_at > now: + self._logger.info( + 'sandbox_reused', + attrs={ + 'chat_id': command.chat_id, + 'session_id': session.session_id, + 'container_id': session.container_id, + }, + ) + return session + + if session is not None: + self._logger.info( + 'sandbox_replaced', + attrs={ + 'chat_id': command.chat_id, + 'session_id': session.session_id, + 'container_id': session.container_id, + }, + ) + self._runtime.stop(session.container_id) + self._repository.delete(session.session_id) + + created_at = self._clock.now() + expires_at = created_at + self._ttl + new_session = self._runtime.create( + session_id=_new_session_id(), + chat_id=command.chat_id, + created_at=created_at, + expires_at=expires_at, + ) + self._repository.save(new_session) self._logger.info( - 'sandbox_reused', + 'sandbox_created', attrs={ 'chat_id': command.chat_id, - 'session_id': session.session_id, - 'container_id': session.container_id, + 'session_id': new_session.session_id, + 'container_id': new_session.container_id, }, ) - return session - - if session is not None: - self._logger.info( - 'sandbox_replaced', - attrs={ - 'chat_id': command.chat_id, - 'session_id': session.session_id, - 'container_id': session.container_id, - }, - ) - self._runtime.stop(session.container_id) - self._repository.delete(session.session_id) - - expires_at = now + self._ttl - new_session = self._runtime.create( - session_id=_new_session_id(), - chat_id=command.chat_id, - created_at=now, - expires_at=expires_at, - ) - self._repository.save(new_session) - self._logger.info( - 'sandbox_created', - attrs={ - 'chat_id': command.chat_id, - 'session_id': new_session.session_id, - 'container_id': new_session.container_id, - }, - ) - return new_session + return new_session class CleanupExpiredSandboxes: def __init__( self, repository: SandboxSessionRepository, + locker: SandboxLifecycleLocker, runtime: SandboxRuntime, clock: Clock, logger: Logger, ) -> None: self._repository = repository + self._locker = locker self._runtime = runtime self._clock = clock self._logger = logger def execute(self) -> list[SandboxSession]: - now = self._clock.now() - expired_sessions = self._repository.list_expired(now) + expired_sessions = self._repository.list_expired(self._clock.now()) cleaned_sessions: list[SandboxSession] = [] for session in expired_sessions: - self._runtime.stop(session.container_id) - self._repository.delete(session.session_id) - cleaned_sessions.append(session) - self._logger.info( - 'sandbox_cleaned', - attrs={ - 'chat_id': session.chat_id, - 'session_id': session.session_id, - 'container_id': session.container_id, - }, - ) + with self._locker.lock(session.chat_id): + current_session = self._repository.get_active_by_chat_id( + session.chat_id + ) + now = self._clock.now() + if current_session is None: + continue + + if current_session.session_id != session.session_id: + continue + + if current_session.expires_at > now: + continue + + self._runtime.stop(current_session.container_id) + self._repository.delete(current_session.session_id) + cleaned_sessions.append(current_session) + self._logger.info( + 'sandbox_cleaned', + attrs={ + 'chat_id': current_session.chat_id, + 'session_id': current_session.session_id, + 'container_id': current_session.container_id, + }, + ) return cleaned_sessions