[fix] race condition

This commit is contained in:
Azamat 2026-04-02 20:56:26 +03:00
parent fb974fff1e
commit f5d13feaf9
7 changed files with 185 additions and 63 deletions

View file

@ -12,6 +12,7 @@ from adapter.docker.runtime import DockerSandboxRuntime
from adapter.observability.factory import build_observability
from adapter.observability.runtime import ObservabilityRuntime
from domain.user import User
from repository.sandbox_lock import ProcessLocalSandboxLifecycleLocker
from repository.sandbox_session import InMemorySandboxSessionRepository
from repository.user import InMemoryUserRepository
from usecase.interface import Clock
@ -85,6 +86,7 @@ def build_container(
observability.tracer, [User(id='123', email='aza@gglamer.ru', name='gglamer')]
)
sandbox_repository = InMemorySandboxSessionRepository()
sandbox_locker = ProcessLocalSandboxLifecycleLocker()
sandbox_runtime = DockerSandboxRuntime(app_config.sandbox, docker_client)
repositories = AppRepositories(
@ -99,6 +101,7 @@ def build_container(
),
create_sandbox=CreateSandbox(
repository=sandbox_repository,
locker=sandbox_locker,
runtime=sandbox_runtime,
clock=clock,
logger=observability.logger,
@ -106,6 +109,7 @@ def build_container(
),
cleanup_expired_sandboxes=CleanupExpiredSandboxes(
repository=sandbox_repository,
locker=sandbox_locker,
runtime=sandbox_runtime,
clock=clock,
logger=observability.logger,

View file

@ -0,0 +1,43 @@
import threading
from types import TracebackType
from typing import Protocol
from usecase.interface import LockContext, SandboxLifecycleLocker
class _SyncLock(Protocol):
    """Structural type for the lock objects this module stores and wraps.

    Only the two operations used by the module are required; the signature
    of ``acquire`` matches the one exposed by ``threading.Lock``.
    """

    def acquire(self, blocking: bool = True, timeout: float = -1) -> bool:
        """Take the lock and report whether it was obtained."""
        ...

    def release(self) -> None:
        """Give the lock back."""
        ...
class _ChatLock(LockContext):
    """Context manager that holds one wrapped lock for the ``with`` body."""

    def __init__(self, lock: _SyncLock) -> None:
        # Nothing is acquired at construction time; that happens on entry.
        self._lock = lock

    def __enter__(self) -> None:
        """Acquire the wrapped lock using its default arguments.

        Nothing is bound by ``with ... as``; the method returns ``None``.
        """
        self._lock.acquire()

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        traceback: TracebackType | None,
    ) -> bool | None:
        """Release the wrapped lock.

        Returns ``None`` so an exception raised in the ``with`` body is not
        suppressed.
        """
        self._lock.release()
        return None
class ProcessLocalSandboxLifecycleLocker(SandboxLifecycleLocker):
    """Hand out one in-process ``threading.Lock`` per ``chat_id``.

    Every context returned for the same ``chat_id`` while that chat's lock is
    still referenced wraps the same lock object, so their ``with`` bodies
    cannot overlap. The locks are ordinary ``threading`` locks and therefore
    only coordinate threads inside the current process.

    The registry keeps its locks through weak references. A lock stays
    registered for as long as some ``_ChatLock`` (which stores the lock on
    itself) is alive, and its entry disappears once the last such context is
    gone. This keeps the registry from growing by one entry for every chat id
    ever seen.
    """

    def __init__(self) -> None:
        # Function-scope import: keeps the extra dependency local to this
        # class instead of touching the module's import block.
        import weakref

        # Guards lookups and insertions in the registry below.
        self._registry_lock = threading.Lock()
        self._locks_by_chat_id: weakref.WeakValueDictionary[str, _SyncLock] = (
            weakref.WeakValueDictionary()
        )

    def lock(self, chat_id: str) -> LockContext:
        """Return a context manager guarding the lock for ``chat_id``.

        The lock itself is not acquired here; that happens when the returned
        context is entered.

        Args:
            chat_id: Key selecting which per-chat lock to use.

        Returns:
            A ``_ChatLock`` wrapping the lock registered for ``chat_id``; a
            new lock is created and registered when none is currently alive.
        """
        with self._registry_lock:
            # ``lock`` is a strong reference, so the entry cannot vanish
            # between this lookup and the construction of the context.
            lock = self._locks_by_chat_id.get(chat_id)
            if lock is None:
                lock = threading.Lock()
                self._locks_by_chat_id[chat_id] = lock
            return _ChatLock(lock)

View file

@ -1,3 +1,4 @@
import threading
from datetime import datetime
from domain.sandbox import SandboxSession
@ -7,11 +8,14 @@ from usecase.interface import SandboxSessionRepository
class InMemorySandboxSessionRepository(SandboxSessionRepository):
def __init__(self) -> None:
    """Start with an empty session mapping and a fresh mutex."""
    # Every method of this class that is visible here touches the mapping
    # only while holding this lock.
    self._lock = threading.Lock()
    self._sessions_by_chat_id: dict[str, SandboxSession] = {}
def get_active_by_chat_id(self, chat_id: str) -> SandboxSession | None:
    """Look up the session stored under ``chat_id``.

    Returns:
        The stored session, or ``None`` when the mapping has no entry for
        that chat. No expiry check is performed here.
    """
    with self._lock:
        found = self._sessions_by_chat_id.get(chat_id)
    return found
def list_expired(self, now: datetime) -> list[SandboxSession]:
with self._lock:
return [
session
for session in self._sessions_by_chat_id.values()
@ -19,9 +23,11 @@ class InMemorySandboxSessionRepository(SandboxSessionRepository):
]
def save(self, session: SandboxSession) -> None:
    """Store ``session`` under its ``chat_id``, replacing any earlier entry."""
    key = session.chat_id
    with self._lock:
        self._sessions_by_chat_id[key] = session
def delete(self, session_id: str) -> None:
    """Remove every stored session whose ``session_id`` matches.

    An id that matches nothing leaves the mapping untouched.
    """
    with self._lock:
        # Collect the matching keys first so the mapping is never resized
        # while it is being iterated.
        matching_chat_ids = [
            chat_id
            for chat_id, session in self._sessions_by_chat_id.items()
            if session.session_id == session_id
        ]
        for chat_id in matching_chat_ids:
            del self._sessions_by_chat_id[chat_id]

View file

@ -120,7 +120,7 @@
### M09. Сериализация lifecycle sandbox по `chat_id`
- Субагент: `feature-developer`
- Статус: pending
- Статус: completed
- Зависимости: `M08`
- Commit required: no
- Scope: убрать гонки между параллельными `create` и cleanup для одного `chat_id`

View file

@ -32,6 +32,23 @@ class FakeLogger:
self.messages.append(('error', message, attrs))
class FakeLockContext:
    """No-op context manager used by the tests in place of a real lock."""

    def __enter__(self) -> None:
        """Do nothing on entry; ``with ... as`` binds ``None``."""

    def __exit__(self, exc_type, exc, traceback) -> None:
        """Do nothing on exit; exceptions from the body are not suppressed."""
class FakeLocker:
    """Locker test double that records each requested ``chat_id`` in order."""

    def __init__(self) -> None:
        # Public so tests can assert on the recorded sequence directly.
        self.chat_ids: list[str] = []

    def lock(self, chat_id: str) -> FakeLockContext:
        """Record ``chat_id`` and return a fresh no-op context."""
        context = FakeLockContext()
        self.chat_ids.append(chat_id)
        return context
class FakeRuntime:
def __init__(self) -> None:
self.create_calls: list[dict[str, object]] = []
@ -80,8 +97,10 @@ def test_create_sandbox_reuses_active_session_when_not_expired() -> None:
repository.save(session)
runtime = FakeRuntime()
logger = FakeLogger()
locker = FakeLocker()
usecase = CreateSandbox(
repository=repository,
locker=locker,
runtime=runtime,
clock=FakeClock(now),
logger=logger,
@ -94,6 +113,7 @@ def test_create_sandbox_reuses_active_session_when_not_expired() -> None:
assert runtime.create_calls == []
assert runtime.stop_calls == []
assert repository.get_active_by_chat_id('chat-1') == session
assert locker.chat_ids == ['chat-1']
assert logger.messages == [
(
'info',
@ -123,8 +143,10 @@ def test_create_sandbox_replaces_expired_session_and_creates_new_one(
repository.save(expired_session)
runtime = FakeRuntime()
logger = FakeLogger()
locker = FakeLocker()
usecase = CreateSandbox(
repository=repository,
locker=locker,
runtime=runtime,
clock=FakeClock(now),
logger=logger,
@ -152,6 +174,7 @@ def test_create_sandbox_replaces_expired_session_and_creates_new_one(
expires_at=now + timedelta(minutes=5),
)
assert repository.get_active_by_chat_id('chat-1') == result
assert locker.chat_ids == ['chat-1']
assert logger.messages == [
(
'info',
@ -179,8 +202,10 @@ def test_create_sandbox_creates_new_session_when_none_exists() -> None:
repository = InMemorySandboxSessionRepository()
runtime = FakeRuntime()
logger = FakeLogger()
locker = FakeLocker()
usecase = CreateSandbox(
repository=repository,
locker=locker,
runtime=runtime,
clock=FakeClock(now),
logger=logger,
@ -203,6 +228,7 @@ def test_create_sandbox_creates_new_session_when_none_exists() -> None:
}
assert runtime.stop_calls == []
assert repository.get_active_by_chat_id('chat-1') == result
assert locker.chat_ids == ['chat-1']
assert logger.messages == [
(
'info',
@ -248,8 +274,10 @@ def test_cleanup_expired_sandboxes_stops_and_deletes_only_expired_sessions() ->
repository.save(active_session)
runtime = FakeRuntime()
logger = FakeLogger()
locker = FakeLocker()
usecase = CleanupExpiredSandboxes(
repository=repository,
locker=locker,
runtime=runtime,
clock=FakeClock(now),
logger=logger,
@ -262,6 +290,7 @@ def test_cleanup_expired_sandboxes_stops_and_deletes_only_expired_sessions() ->
assert repository.get_active_by_chat_id('chat-expired') is None
assert repository.get_active_by_chat_id('chat-boundary') is None
assert repository.get_active_by_chat_id('chat-active') == active_session
assert locker.chat_ids == ['chat-expired', 'chat-boundary']
assert logger.messages == [
(
'info',

View file

@ -28,6 +28,21 @@ class SandboxSessionRepository(Protocol):
def delete(self, session_id: str) -> None: ...
class LockContext(Protocol):
def __enter__(self) -> None: ...
def __exit__(
self,
exc_type: type[BaseException] | None,
exc: BaseException | None,
traceback: TracebackType | None,
) -> bool | None: ...
class SandboxLifecycleLocker(Protocol):
    """Provider of per-chat lock contexts."""

    def lock(self, chat_id: str) -> LockContext:
        """Return a context manager associated with ``chat_id``."""
        ...
class SandboxRuntime(Protocol):
def create(
self,

View file

@ -3,7 +3,13 @@ from datetime import timedelta
from uuid import uuid4
from domain.sandbox import SandboxSession
from usecase.interface import Clock, Logger, SandboxRuntime, SandboxSessionRepository
from usecase.interface import (
Clock,
Logger,
SandboxLifecycleLocker,
SandboxRuntime,
SandboxSessionRepository,
)
@dataclass(frozen=True, slots=True)
@ -15,20 +21,23 @@ class CreateSandbox:
def __init__(
self,
repository: SandboxSessionRepository,
locker: SandboxLifecycleLocker,
runtime: SandboxRuntime,
clock: Clock,
logger: Logger,
ttl: timedelta,
) -> None:
self._repository = repository
self._locker = locker
self._runtime = runtime
self._clock = clock
self._logger = logger
self._ttl = ttl
def execute(self, command: CreateSandboxCommand) -> SandboxSession:
now = self._clock.now()
with self._locker.lock(command.chat_id):
session = self._repository.get_active_by_chat_id(command.chat_id)
now = self._clock.now()
if session is not None and session.expires_at > now:
self._logger.info(
@ -53,11 +62,12 @@ class CreateSandbox:
self._runtime.stop(session.container_id)
self._repository.delete(session.session_id)
expires_at = now + self._ttl
created_at = self._clock.now()
expires_at = created_at + self._ttl
new_session = self._runtime.create(
session_id=_new_session_id(),
chat_id=command.chat_id,
created_at=now,
created_at=created_at,
expires_at=expires_at,
)
self._repository.save(new_session)
@ -76,30 +86,45 @@ class CleanupExpiredSandboxes:
def __init__(
self,
repository: SandboxSessionRepository,
locker: SandboxLifecycleLocker,
runtime: SandboxRuntime,
clock: Clock,
logger: Logger,
) -> None:
self._repository = repository
self._locker = locker
self._runtime = runtime
self._clock = clock
self._logger = logger
def execute(self) -> list[SandboxSession]:
now = self._clock.now()
expired_sessions = self._repository.list_expired(now)
expired_sessions = self._repository.list_expired(self._clock.now())
cleaned_sessions: list[SandboxSession] = []
for session in expired_sessions:
self._runtime.stop(session.container_id)
self._repository.delete(session.session_id)
cleaned_sessions.append(session)
with self._locker.lock(session.chat_id):
current_session = self._repository.get_active_by_chat_id(
session.chat_id
)
now = self._clock.now()
if current_session is None:
continue
if current_session.session_id != session.session_id:
continue
if current_session.expires_at > now:
continue
self._runtime.stop(current_session.container_id)
self._repository.delete(current_session.session_id)
cleaned_sessions.append(current_session)
self._logger.info(
'sandbox_cleaned',
attrs={
'chat_id': session.chat_id,
'session_id': session.session_id,
'container_id': session.container_id,
'chat_id': current_session.chat_id,
'session_id': current_session.session_id,
'container_id': current_session.container_id,
},
)