import threading
from dataclasses import dataclass
from types import TracebackType
from typing import Protocol
from uuid import UUID

from usecase.interface import LockContext, SandboxLifecycleLocker


class _SyncLock(Protocol):
    """Structural type for the per-chat lock stored in the registry.

    ``threading.Lock`` (used by ``ProcessLocalSandboxLifecycleLocker``)
    satisfies this protocol.
    """

    def acquire(self, blocking: bool = True, timeout: float = -1) -> bool: ...

    def release(self) -> None: ...


@dataclass(slots=True)
class _LockEntry:
    """Registry record pairing a per-chat lock with its reference count."""

    lock: _SyncLock
    # Number of outstanding ``_ChatLock`` handles that still hold a reference
    # to this entry; the entry is removed from the registry when it reaches 0.
    users: int = 0


class _ChatLock(LockContext):
    """Context manager that holds one chat's lock for the duration of a ``with``.

    Each handle returned by ``ProcessLocalSandboxLifecycleLocker.lock`` owns
    one reference on its registry entry. That reference is dropped in
    ``__exit__``, or in ``__enter__`` if acquiring the lock fails.

    A handle is meant for exactly one ``with`` statement. A handle that is
    never entered never drops its reference, so its registry entry stays
    alive.
    """

    def __init__(
        self,
        locker: 'ProcessLocalSandboxLifecycleLocker',
        chat_id: UUID,
        entry: _LockEntry,
    ) -> None:
        """Bind the handle to its owning locker and registry entry.

        Args:
            locker: Registry that created this handle; notified on release.
            chat_id: Key of ``entry`` in the locker's registry.
            entry: Registry record whose ``users`` count already includes
                this handle (incremented by ``locker.lock``).
        """
        self._locker = locker
        self._chat_id = chat_id
        self._entry = entry

    def __enter__(self) -> None:
        """Block until the chat's lock is acquired."""
        try:
            self._entry.lock.acquire()
        except BaseException:
            # ``__exit__`` is not invoked when ``__enter__`` raises (e.g. a
            # KeyboardInterrupt while blocked in acquire), so the registry
            # reference taken by ``lock()`` must be dropped here; otherwise
            # the entry would never be removed from the registry.
            self._locker._release(self._chat_id, self._entry)
            raise

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        traceback: TracebackType | None,
    ) -> bool | None:
        """Release the chat's lock and drop this handle's registry reference.

        Returns:
            ``None``, so any exception raised inside the ``with`` body
            propagates.
        """
        try:
            self._entry.lock.release()
        finally:
            # Drop the registry reference even if ``release()`` raised, so a
            # failed release cannot leak the entry.
            self._locker._release(self._chat_id, self._entry)
        return None


class ProcessLocalSandboxLifecycleLocker(SandboxLifecycleLocker):
    """Per-chat mutual exclusion backed by an in-memory registry of locks.

    Locks are ``threading.Lock`` objects kept in a dict on this instance, so
    they serialize threads of the current process that share this locker. They
    do not coordinate with other processes.
    """

    def __init__(self) -> None:
        """Create an empty registry."""
        # Guards ``_locks_by_chat_id`` and every ``_LockEntry.users`` count.
        self._registry_lock = threading.Lock()
        self._locks_by_chat_id: dict[UUID, _LockEntry] = {}

    def lock(self, chat_id: UUID) -> LockContext:
        """Return a context manager that holds the lock for ``chat_id``.

        The lock itself is acquired only when the returned context is
        entered. Creating the handle registers a reference on the chat's
        entry, creating the entry on first use.

        Args:
            chat_id: Identifier of the chat to serialize on.

        Returns:
            A single-use ``LockContext`` for ``chat_id``.
        """
        with self._registry_lock:
            entry = self._locks_by_chat_id.get(chat_id)
            if entry is None:
                entry = _LockEntry(lock=threading.Lock())
                self._locks_by_chat_id[chat_id] = entry
            entry.users += 1
        return _ChatLock(self, chat_id, entry)

    def _release(self, chat_id: UUID, entry: _LockEntry) -> None:
        """Drop one reference on ``entry``; unregister it when none remain.

        Args:
            chat_id: Registry key the entry was stored under.
            entry: The record whose reference count is decremented.
        """
        with self._registry_lock:
            entry.users -= 1
            if entry.users != 0:
                return
            # Only remove the mapping if it still points at this exact entry,
            # so a different entry registered under the same key is untouched.
            current_entry = self._locks_by_chat_id.get(chat_id)
            if current_entry is entry:
                del self._locks_by_chat_id[chat_id]