[fix] race condition
This commit is contained in:
parent
fb974fff1e
commit
f5d13feaf9
7 changed files with 185 additions and 63 deletions
43
repository/sandbox_lock.py
Normal file
43
repository/sandbox_lock.py
Normal file
|
|
@ -0,0 +1,43 @@
|
|||
import threading
|
||||
from types import TracebackType
|
||||
from typing import Protocol
|
||||
|
||||
from usecase.interface import LockContext, SandboxLifecycleLocker
|
||||
|
||||
|
||||
class _SyncLock(Protocol):
|
||||
def acquire(self, blocking: bool = True, timeout: float = -1) -> bool: ...
|
||||
|
||||
def release(self) -> None: ...
|
||||
|
||||
|
||||
class _ChatLock(LockContext):
|
||||
def __init__(self, lock: _SyncLock) -> None:
|
||||
self._lock = lock
|
||||
|
||||
def __enter__(self) -> None:
|
||||
self._lock.acquire()
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc: BaseException | None,
|
||||
traceback: TracebackType | None,
|
||||
) -> bool | None:
|
||||
self._lock.release()
|
||||
return None
|
||||
|
||||
|
||||
class ProcessLocalSandboxLifecycleLocker(SandboxLifecycleLocker):
|
||||
def __init__(self) -> None:
|
||||
self._registry_lock = threading.Lock()
|
||||
self._locks_by_chat_id: dict[str, _SyncLock] = {}
|
||||
|
||||
def lock(self, chat_id: str) -> LockContext:
|
||||
with self._registry_lock:
|
||||
lock = self._locks_by_chat_id.get(chat_id)
|
||||
if lock is None:
|
||||
lock = threading.Lock()
|
||||
self._locks_by_chat_id[chat_id] = lock
|
||||
|
||||
return _ChatLock(lock)
|
||||
|
|
@ -1,3 +1,4 @@
|
|||
import threading
|
||||
from datetime import datetime
|
||||
|
||||
from domain.sandbox import SandboxSession
|
||||
|
|
@ -7,22 +8,27 @@ from usecase.interface import SandboxSessionRepository
|
|||
class InMemorySandboxSessionRepository(SandboxSessionRepository):
|
||||
def __init__(self) -> None:
|
||||
self._sessions_by_chat_id: dict[str, SandboxSession] = {}
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def get_active_by_chat_id(self, chat_id: str) -> SandboxSession | None:
|
||||
return self._sessions_by_chat_id.get(chat_id)
|
||||
with self._lock:
|
||||
return self._sessions_by_chat_id.get(chat_id)
|
||||
|
||||
def list_expired(self, now: datetime) -> list[SandboxSession]:
|
||||
return [
|
||||
session
|
||||
for session in self._sessions_by_chat_id.values()
|
||||
if session.expires_at <= now
|
||||
]
|
||||
with self._lock:
|
||||
return [
|
||||
session
|
||||
for session in self._sessions_by_chat_id.values()
|
||||
if session.expires_at <= now
|
||||
]
|
||||
|
||||
def save(self, session: SandboxSession) -> None:
|
||||
self._sessions_by_chat_id[session.chat_id] = session
|
||||
with self._lock:
|
||||
self._sessions_by_chat_id[session.chat_id] = session
|
||||
|
||||
def delete(self, session_id: str) -> None:
|
||||
for chat_id, session in tuple(self._sessions_by_chat_id.items()):
|
||||
if session.session_id == session_id:
|
||||
del self._sessions_by_chat_id[chat_id]
|
||||
return
|
||||
with self._lock:
|
||||
for chat_id, session in tuple(self._sessions_by_chat_id.items()):
|
||||
if session.session_id == session_id:
|
||||
del self._sessions_by_chat_id[chat_id]
|
||||
return
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue