master/repository/sandbox_lock.py
2026-04-02 23:52:36 +03:00

70 lines
1.9 KiB
Python

import threading
from dataclasses import dataclass
from types import TracebackType
from typing import Protocol
from uuid import UUID
from usecase.interface import LockContext, SandboxLifecycleLocker
class _SyncLock(Protocol):
    """Structural type for the blocking lock kept in each registry entry.

    ``threading.Lock`` satisfies it; callers in this module only ever use
    ``acquire`` and ``release``.
    """

    def acquire(self, blocking: bool = True, timeout: float = -1) -> bool:
        """Take the lock and report whether it was obtained."""

    def release(self) -> None:
        """Give the lock back."""
@dataclass(slots=True)
class _LockEntry:
    """Registry record pairing one chat's lock with its outstanding-user count.

    The locker removes the record from its registry once ``users`` drops
    back to zero.
    """

    # Mutex serialising lifecycle work for a single chat id.
    lock: _SyncLock
    # Contexts handed out for this chat that have not been released yet;
    # the locker changes it only while holding its registry lock.
    users: int = 0
class _ChatLock(LockContext):
    """Context manager that holds one chat's lock for the ``with`` body.

    ``ProcessLocalSandboxLifecycleLocker.lock`` has already counted this
    context as a user of ``entry`` before handing it out. Exactly one
    ``_release`` call must follow, so that the registry can forget the entry
    once no context references it.
    """

    def __init__(
        self,
        locker: 'ProcessLocalSandboxLifecycleLocker',
        chat_id: UUID,
        entry: _LockEntry,
    ) -> None:
        self._locker = locker
        self._chat_id = chat_id
        self._entry = entry

    def __enter__(self) -> None:
        """Block until the chat's lock is acquired.

        If ``acquire`` raises (e.g. ``KeyboardInterrupt`` while blocked),
        Python will not call ``__exit__``. The user reference taken by
        ``lock()`` is therefore returned here before re-raising; otherwise
        the registry entry could never be removed.
        """
        try:
            self._entry.lock.acquire()
        except BaseException:
            self._locker._release(self._chat_id, self._entry)
            raise

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        traceback: TracebackType | None,
    ) -> bool | None:
        """Release the chat's lock and drop this context's registry reference.

        Returns ``None`` so any exception raised in the ``with`` body
        propagates unchanged.
        """
        try:
            self._entry.lock.release()
        finally:
            # Drop the reference even if release() raised, so the user
            # count (and with it the registry entry) cannot leak.
            self._locker._release(self._chat_id, self._entry)
        return None
class ProcessLocalSandboxLifecycleLocker(SandboxLifecycleLocker):
    """Serialise sandbox lifecycle work per chat id within this process.

    A ``threading.Lock`` is kept for a chat id while at least one context
    obtained from ``lock`` is outstanding. The entry is dropped from the
    registry when the last such context is released. The locks are plain
    ``threading.Lock`` objects, so exclusion covers only threads of the
    current process.
    """

    def __init__(self) -> None:
        # Protects the dict below as well as every entry's ``users`` counter.
        self._registry_lock = threading.Lock()
        self._locks_by_chat_id: dict[UUID, _LockEntry] = {}

    def lock(self, chat_id: UUID) -> LockContext:
        """Return a context manager granting exclusive access for ``chat_id``.

        The returned context is counted as a user of the chat's entry
        immediately, i.e. before it is entered.
        """
        with self._registry_lock:
            registry = self._locks_by_chat_id
            try:
                slot = registry[chat_id]
            except KeyError:
                slot = registry[chat_id] = _LockEntry(lock=threading.Lock())
            slot.users += 1
            return _ChatLock(self, chat_id, slot)

    def _release(self, chat_id: UUID, entry: _LockEntry) -> None:
        """Drop one user from ``entry``; forget the entry when none remain."""
        with self._registry_lock:
            entry.users -= 1
            unused = entry.users == 0
            # Only remove the mapping if it still points at this very entry.
            if unused and self._locks_by_chat_id.get(chat_id) is entry:
                del self._locks_by_chat_id[chat_id]