diff --git a/repository/sandbox_lock.py b/repository/sandbox_lock.py index b13cd65..ffa0e35 100644 --- a/repository/sandbox_lock.py +++ b/repository/sandbox_lock.py @@ -1,4 +1,5 @@ import threading +from dataclasses import dataclass from types import TracebackType from typing import Protocol from uuid import UUID @@ -12,12 +13,25 @@ class _SyncLock(Protocol): def release(self) -> None: ... +@dataclass(slots=True) +class _LockEntry: + lock: _SyncLock + users: int = 0 + + class _ChatLock(LockContext): - def __init__(self, lock: _SyncLock) -> None: - self._lock = lock + def __init__( + self, + locker: 'ProcessLocalSandboxLifecycleLocker', + chat_id: UUID, + entry: _LockEntry, + ) -> None: + self._locker = locker + self._chat_id = chat_id + self._entry = entry def __enter__(self) -> None: - self._lock.acquire() + self._entry.lock.acquire() def __exit__( self, @@ -25,20 +39,32 @@ class _ChatLock(LockContext): exc: BaseException | None, traceback: TracebackType | None, ) -> bool | None: - self._lock.release() + self._entry.lock.release() + self._locker._release(self._chat_id, self._entry) return None class ProcessLocalSandboxLifecycleLocker(SandboxLifecycleLocker): def __init__(self) -> None: self._registry_lock = threading.Lock() - self._locks_by_chat_id: dict[UUID, _SyncLock] = {} + self._locks_by_chat_id: dict[UUID, _LockEntry] = {} def lock(self, chat_id: UUID) -> LockContext: with self._registry_lock: - lock = self._locks_by_chat_id.get(chat_id) - if lock is None: - lock = threading.Lock() - self._locks_by_chat_id[chat_id] = lock + entry = self._locks_by_chat_id.get(chat_id) + if entry is None: + entry = _LockEntry(lock=threading.Lock()) + self._locks_by_chat_id[chat_id] = entry + entry.users += 1 - return _ChatLock(lock) + return _ChatLock(self, chat_id, entry) + + def _release(self, chat_id: UUID, entry: _LockEntry) -> None: + with self._registry_lock: + entry.users -= 1 + if entry.users != 0: + return + + current_entry = self._locks_by_chat_id.get(chat_id) + if current_entry is entry: + del self._locks_by_chat_id[chat_id] diff --git a/tasks.md b/tasks.md index 1101e17..494c655 100644 --- a/tasks.md +++ b/tasks.md @@ -209,7 +209,7 @@ ### M17. Управление жизненным циклом per-chat locks - Субагент: `feature-developer` -- Статус: pending +- Статус: completed - Зависимости: `M13` - Commit required: no - Scope: ограничить неограниченный рост registry locks по числу когда-либо увиденных `chat_id` diff --git a/test/test_sandbox_lock.py b/test/test_sandbox_lock.py new file mode 100644 index 0000000..1177cec --- /dev/null +++ b/test/test_sandbox_lock.py @@ -0,0 +1,93 @@ +import threading +from uuid import UUID + +from repository.sandbox_lock import ProcessLocalSandboxLifecycleLocker + +CHAT_ID = UUID('77777777-7777-7777-7777-777777777777') + + +class LockRace: + def __init__(self, locker: ProcessLocalSandboxLifecycleLocker) -> None: + self.locker = locker + self.entered_first = threading.Event() + self.second_requested = threading.Event() + self.second_entered = threading.Event() + self.release_first = threading.Event() + self.release_second = threading.Event() + self.errors: list[Exception] = [] + self.order: list[str] = [] + self.first_entry: object | None = None + + def run_first(self) -> None: + try: + with self.locker.lock(CHAT_ID): + self.first_entry = self.locker._locks_by_chat_id[CHAT_ID] + self.order.append('first_entered') + self.entered_first.set() + assert self.release_first.wait(timeout=1) + self.order.append('first_releasing') + except Exception as exc: + self.errors.append(exc) + + def run_second(self) -> None: + try: + assert self.entered_first.wait(timeout=1) + context = self.locker.lock(CHAT_ID) + self.second_requested.set() + + with context: + self.order.append('second_entered') + self.second_entered.set() + assert self.release_second.wait(timeout=1) + self.order.append('second_releasing') + except Exception as exc: + self.errors.append(exc) + + +def test_process_local_sandbox_lifecycle_locker_evicts_idle_lock() -> None: + locker = ProcessLocalSandboxLifecycleLocker() + + with locker.lock(CHAT_ID): + assert CHAT_ID in locker._locks_by_chat_id + assert len(locker._locks_by_chat_id) == 1 + + assert CHAT_ID not in locker._locks_by_chat_id + assert len(locker._locks_by_chat_id) == 0 + + +def test_process_local_sandbox_lifecycle_locker_keeps_shared_lock_for_waiters() -> None: + locker = ProcessLocalSandboxLifecycleLocker() + race = LockRace(locker) + first_thread = threading.Thread(target=race.run_first) + second_thread = threading.Thread(target=race.run_second) + + first_thread.start() + assert race.entered_first.wait(timeout=1) + + second_thread.start() + assert race.second_requested.wait(timeout=1) + assert len(locker._locks_by_chat_id) == 1 + assert locker._locks_by_chat_id[CHAT_ID] is race.first_entry + assert not race.second_entered.wait(timeout=0.1) + + race.release_first.set() + assert race.second_entered.wait(timeout=1) + assert len(locker._locks_by_chat_id) == 1 + assert locker._locks_by_chat_id[CHAT_ID] is race.first_entry + + race.release_second.set() + + first_thread.join(timeout=1) + second_thread.join(timeout=1) + + assert not first_thread.is_alive() + assert not second_thread.is_alive() + assert race.errors == [] + assert race.order == [ + 'first_entered', + 'first_releasing', + 'second_entered', + 'second_releasing', + ] + assert CHAT_ID not in locker._locks_by_chat_id + assert len(locker._locks_by_chat_id) == 0