[fix] race condition
This commit is contained in:
parent
fb974fff1e
commit
f5d13feaf9
7 changed files with 185 additions and 63 deletions
|
|
@ -12,6 +12,7 @@ from adapter.docker.runtime import DockerSandboxRuntime
|
||||||
from adapter.observability.factory import build_observability
|
from adapter.observability.factory import build_observability
|
||||||
from adapter.observability.runtime import ObservabilityRuntime
|
from adapter.observability.runtime import ObservabilityRuntime
|
||||||
from domain.user import User
|
from domain.user import User
|
||||||
|
from repository.sandbox_lock import ProcessLocalSandboxLifecycleLocker
|
||||||
from repository.sandbox_session import InMemorySandboxSessionRepository
|
from repository.sandbox_session import InMemorySandboxSessionRepository
|
||||||
from repository.user import InMemoryUserRepository
|
from repository.user import InMemoryUserRepository
|
||||||
from usecase.interface import Clock
|
from usecase.interface import Clock
|
||||||
|
|
@ -85,6 +86,7 @@ def build_container(
|
||||||
observability.tracer, [User(id='123', email='aza@gglamer.ru', name='gglamer')]
|
observability.tracer, [User(id='123', email='aza@gglamer.ru', name='gglamer')]
|
||||||
)
|
)
|
||||||
sandbox_repository = InMemorySandboxSessionRepository()
|
sandbox_repository = InMemorySandboxSessionRepository()
|
||||||
|
sandbox_locker = ProcessLocalSandboxLifecycleLocker()
|
||||||
sandbox_runtime = DockerSandboxRuntime(app_config.sandbox, docker_client)
|
sandbox_runtime = DockerSandboxRuntime(app_config.sandbox, docker_client)
|
||||||
|
|
||||||
repositories = AppRepositories(
|
repositories = AppRepositories(
|
||||||
|
|
@ -99,6 +101,7 @@ def build_container(
|
||||||
),
|
),
|
||||||
create_sandbox=CreateSandbox(
|
create_sandbox=CreateSandbox(
|
||||||
repository=sandbox_repository,
|
repository=sandbox_repository,
|
||||||
|
locker=sandbox_locker,
|
||||||
runtime=sandbox_runtime,
|
runtime=sandbox_runtime,
|
||||||
clock=clock,
|
clock=clock,
|
||||||
logger=observability.logger,
|
logger=observability.logger,
|
||||||
|
|
@ -106,6 +109,7 @@ def build_container(
|
||||||
),
|
),
|
||||||
cleanup_expired_sandboxes=CleanupExpiredSandboxes(
|
cleanup_expired_sandboxes=CleanupExpiredSandboxes(
|
||||||
repository=sandbox_repository,
|
repository=sandbox_repository,
|
||||||
|
locker=sandbox_locker,
|
||||||
runtime=sandbox_runtime,
|
runtime=sandbox_runtime,
|
||||||
clock=clock,
|
clock=clock,
|
||||||
logger=observability.logger,
|
logger=observability.logger,
|
||||||
|
|
|
||||||
43
repository/sandbox_lock.py
Normal file
43
repository/sandbox_lock.py
Normal file
|
|
@ -0,0 +1,43 @@
|
||||||
|
import threading
|
||||||
|
from types import TracebackType
|
||||||
|
from typing import Protocol
|
||||||
|
|
||||||
|
from usecase.interface import LockContext, SandboxLifecycleLocker
|
||||||
|
|
||||||
|
|
||||||
|
class _SyncLock(Protocol):
|
||||||
|
def acquire(self, blocking: bool = True, timeout: float = -1) -> bool: ...
|
||||||
|
|
||||||
|
def release(self) -> None: ...
|
||||||
|
|
||||||
|
|
||||||
|
class _ChatLock(LockContext):
|
||||||
|
def __init__(self, lock: _SyncLock) -> None:
|
||||||
|
self._lock = lock
|
||||||
|
|
||||||
|
def __enter__(self) -> None:
|
||||||
|
self._lock.acquire()
|
||||||
|
|
||||||
|
def __exit__(
|
||||||
|
self,
|
||||||
|
exc_type: type[BaseException] | None,
|
||||||
|
exc: BaseException | None,
|
||||||
|
traceback: TracebackType | None,
|
||||||
|
) -> bool | None:
|
||||||
|
self._lock.release()
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class ProcessLocalSandboxLifecycleLocker(SandboxLifecycleLocker):
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._registry_lock = threading.Lock()
|
||||||
|
self._locks_by_chat_id: dict[str, _SyncLock] = {}
|
||||||
|
|
||||||
|
def lock(self, chat_id: str) -> LockContext:
|
||||||
|
with self._registry_lock:
|
||||||
|
lock = self._locks_by_chat_id.get(chat_id)
|
||||||
|
if lock is None:
|
||||||
|
lock = threading.Lock()
|
||||||
|
self._locks_by_chat_id[chat_id] = lock
|
||||||
|
|
||||||
|
return _ChatLock(lock)
|
||||||
|
|
@ -1,3 +1,4 @@
|
||||||
|
import threading
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
from domain.sandbox import SandboxSession
|
from domain.sandbox import SandboxSession
|
||||||
|
|
@ -7,22 +8,27 @@ from usecase.interface import SandboxSessionRepository
|
||||||
class InMemorySandboxSessionRepository(SandboxSessionRepository):
|
class InMemorySandboxSessionRepository(SandboxSessionRepository):
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self._sessions_by_chat_id: dict[str, SandboxSession] = {}
|
self._sessions_by_chat_id: dict[str, SandboxSession] = {}
|
||||||
|
self._lock = threading.Lock()
|
||||||
|
|
||||||
def get_active_by_chat_id(self, chat_id: str) -> SandboxSession | None:
|
def get_active_by_chat_id(self, chat_id: str) -> SandboxSession | None:
|
||||||
return self._sessions_by_chat_id.get(chat_id)
|
with self._lock:
|
||||||
|
return self._sessions_by_chat_id.get(chat_id)
|
||||||
|
|
||||||
def list_expired(self, now: datetime) -> list[SandboxSession]:
|
def list_expired(self, now: datetime) -> list[SandboxSession]:
|
||||||
return [
|
with self._lock:
|
||||||
session
|
return [
|
||||||
for session in self._sessions_by_chat_id.values()
|
session
|
||||||
if session.expires_at <= now
|
for session in self._sessions_by_chat_id.values()
|
||||||
]
|
if session.expires_at <= now
|
||||||
|
]
|
||||||
|
|
||||||
def save(self, session: SandboxSession) -> None:
|
def save(self, session: SandboxSession) -> None:
|
||||||
self._sessions_by_chat_id[session.chat_id] = session
|
with self._lock:
|
||||||
|
self._sessions_by_chat_id[session.chat_id] = session
|
||||||
|
|
||||||
def delete(self, session_id: str) -> None:
|
def delete(self, session_id: str) -> None:
|
||||||
for chat_id, session in tuple(self._sessions_by_chat_id.items()):
|
with self._lock:
|
||||||
if session.session_id == session_id:
|
for chat_id, session in tuple(self._sessions_by_chat_id.items()):
|
||||||
del self._sessions_by_chat_id[chat_id]
|
if session.session_id == session_id:
|
||||||
return
|
del self._sessions_by_chat_id[chat_id]
|
||||||
|
return
|
||||||
|
|
|
||||||
2
tasks.md
2
tasks.md
|
|
@ -120,7 +120,7 @@
|
||||||
### M09. Сериализация lifecycle sandbox по `chat_id`
|
### M09. Сериализация lifecycle sandbox по `chat_id`
|
||||||
|
|
||||||
- Субагент: `feature-developer`
|
- Субагент: `feature-developer`
|
||||||
- Статус: pending
|
- Статус: completed
|
||||||
- Зависимости: `M08`
|
- Зависимости: `M08`
|
||||||
- Commit required: no
|
- Commit required: no
|
||||||
- Scope: убрать гонки между параллельными `create` и cleanup для одного `chat_id`
|
- Scope: убрать гонки между параллельными `create` и cleanup для одного `chat_id`
|
||||||
|
|
|
||||||
|
|
@ -32,6 +32,23 @@ class FakeLogger:
|
||||||
self.messages.append(('error', message, attrs))
|
self.messages.append(('error', message, attrs))
|
||||||
|
|
||||||
|
|
||||||
|
class FakeLockContext:
|
||||||
|
def __enter__(self) -> None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc, traceback) -> None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class FakeLocker:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.chat_ids: list[str] = []
|
||||||
|
|
||||||
|
def lock(self, chat_id: str) -> FakeLockContext:
|
||||||
|
self.chat_ids.append(chat_id)
|
||||||
|
return FakeLockContext()
|
||||||
|
|
||||||
|
|
||||||
class FakeRuntime:
|
class FakeRuntime:
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self.create_calls: list[dict[str, object]] = []
|
self.create_calls: list[dict[str, object]] = []
|
||||||
|
|
@ -80,8 +97,10 @@ def test_create_sandbox_reuses_active_session_when_not_expired() -> None:
|
||||||
repository.save(session)
|
repository.save(session)
|
||||||
runtime = FakeRuntime()
|
runtime = FakeRuntime()
|
||||||
logger = FakeLogger()
|
logger = FakeLogger()
|
||||||
|
locker = FakeLocker()
|
||||||
usecase = CreateSandbox(
|
usecase = CreateSandbox(
|
||||||
repository=repository,
|
repository=repository,
|
||||||
|
locker=locker,
|
||||||
runtime=runtime,
|
runtime=runtime,
|
||||||
clock=FakeClock(now),
|
clock=FakeClock(now),
|
||||||
logger=logger,
|
logger=logger,
|
||||||
|
|
@ -94,6 +113,7 @@ def test_create_sandbox_reuses_active_session_when_not_expired() -> None:
|
||||||
assert runtime.create_calls == []
|
assert runtime.create_calls == []
|
||||||
assert runtime.stop_calls == []
|
assert runtime.stop_calls == []
|
||||||
assert repository.get_active_by_chat_id('chat-1') == session
|
assert repository.get_active_by_chat_id('chat-1') == session
|
||||||
|
assert locker.chat_ids == ['chat-1']
|
||||||
assert logger.messages == [
|
assert logger.messages == [
|
||||||
(
|
(
|
||||||
'info',
|
'info',
|
||||||
|
|
@ -123,8 +143,10 @@ def test_create_sandbox_replaces_expired_session_and_creates_new_one(
|
||||||
repository.save(expired_session)
|
repository.save(expired_session)
|
||||||
runtime = FakeRuntime()
|
runtime = FakeRuntime()
|
||||||
logger = FakeLogger()
|
logger = FakeLogger()
|
||||||
|
locker = FakeLocker()
|
||||||
usecase = CreateSandbox(
|
usecase = CreateSandbox(
|
||||||
repository=repository,
|
repository=repository,
|
||||||
|
locker=locker,
|
||||||
runtime=runtime,
|
runtime=runtime,
|
||||||
clock=FakeClock(now),
|
clock=FakeClock(now),
|
||||||
logger=logger,
|
logger=logger,
|
||||||
|
|
@ -152,6 +174,7 @@ def test_create_sandbox_replaces_expired_session_and_creates_new_one(
|
||||||
expires_at=now + timedelta(minutes=5),
|
expires_at=now + timedelta(minutes=5),
|
||||||
)
|
)
|
||||||
assert repository.get_active_by_chat_id('chat-1') == result
|
assert repository.get_active_by_chat_id('chat-1') == result
|
||||||
|
assert locker.chat_ids == ['chat-1']
|
||||||
assert logger.messages == [
|
assert logger.messages == [
|
||||||
(
|
(
|
||||||
'info',
|
'info',
|
||||||
|
|
@ -179,8 +202,10 @@ def test_create_sandbox_creates_new_session_when_none_exists() -> None:
|
||||||
repository = InMemorySandboxSessionRepository()
|
repository = InMemorySandboxSessionRepository()
|
||||||
runtime = FakeRuntime()
|
runtime = FakeRuntime()
|
||||||
logger = FakeLogger()
|
logger = FakeLogger()
|
||||||
|
locker = FakeLocker()
|
||||||
usecase = CreateSandbox(
|
usecase = CreateSandbox(
|
||||||
repository=repository,
|
repository=repository,
|
||||||
|
locker=locker,
|
||||||
runtime=runtime,
|
runtime=runtime,
|
||||||
clock=FakeClock(now),
|
clock=FakeClock(now),
|
||||||
logger=logger,
|
logger=logger,
|
||||||
|
|
@ -203,6 +228,7 @@ def test_create_sandbox_creates_new_session_when_none_exists() -> None:
|
||||||
}
|
}
|
||||||
assert runtime.stop_calls == []
|
assert runtime.stop_calls == []
|
||||||
assert repository.get_active_by_chat_id('chat-1') == result
|
assert repository.get_active_by_chat_id('chat-1') == result
|
||||||
|
assert locker.chat_ids == ['chat-1']
|
||||||
assert logger.messages == [
|
assert logger.messages == [
|
||||||
(
|
(
|
||||||
'info',
|
'info',
|
||||||
|
|
@ -248,8 +274,10 @@ def test_cleanup_expired_sandboxes_stops_and_deletes_only_expired_sessions() ->
|
||||||
repository.save(active_session)
|
repository.save(active_session)
|
||||||
runtime = FakeRuntime()
|
runtime = FakeRuntime()
|
||||||
logger = FakeLogger()
|
logger = FakeLogger()
|
||||||
|
locker = FakeLocker()
|
||||||
usecase = CleanupExpiredSandboxes(
|
usecase = CleanupExpiredSandboxes(
|
||||||
repository=repository,
|
repository=repository,
|
||||||
|
locker=locker,
|
||||||
runtime=runtime,
|
runtime=runtime,
|
||||||
clock=FakeClock(now),
|
clock=FakeClock(now),
|
||||||
logger=logger,
|
logger=logger,
|
||||||
|
|
@ -262,6 +290,7 @@ def test_cleanup_expired_sandboxes_stops_and_deletes_only_expired_sessions() ->
|
||||||
assert repository.get_active_by_chat_id('chat-expired') is None
|
assert repository.get_active_by_chat_id('chat-expired') is None
|
||||||
assert repository.get_active_by_chat_id('chat-boundary') is None
|
assert repository.get_active_by_chat_id('chat-boundary') is None
|
||||||
assert repository.get_active_by_chat_id('chat-active') == active_session
|
assert repository.get_active_by_chat_id('chat-active') == active_session
|
||||||
|
assert locker.chat_ids == ['chat-expired', 'chat-boundary']
|
||||||
assert logger.messages == [
|
assert logger.messages == [
|
||||||
(
|
(
|
||||||
'info',
|
'info',
|
||||||
|
|
|
||||||
|
|
@ -28,6 +28,21 @@ class SandboxSessionRepository(Protocol):
|
||||||
def delete(self, session_id: str) -> None: ...
|
def delete(self, session_id: str) -> None: ...
|
||||||
|
|
||||||
|
|
||||||
|
class LockContext(Protocol):
|
||||||
|
def __enter__(self) -> None: ...
|
||||||
|
|
||||||
|
def __exit__(
|
||||||
|
self,
|
||||||
|
exc_type: type[BaseException] | None,
|
||||||
|
exc: BaseException | None,
|
||||||
|
traceback: TracebackType | None,
|
||||||
|
) -> bool | None: ...
|
||||||
|
|
||||||
|
|
||||||
|
class SandboxLifecycleLocker(Protocol):
|
||||||
|
def lock(self, chat_id: str) -> LockContext: ...
|
||||||
|
|
||||||
|
|
||||||
class SandboxRuntime(Protocol):
|
class SandboxRuntime(Protocol):
|
||||||
def create(
|
def create(
|
||||||
self,
|
self,
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,13 @@ from datetime import timedelta
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
from domain.sandbox import SandboxSession
|
from domain.sandbox import SandboxSession
|
||||||
from usecase.interface import Clock, Logger, SandboxRuntime, SandboxSessionRepository
|
from usecase.interface import (
|
||||||
|
Clock,
|
||||||
|
Logger,
|
||||||
|
SandboxLifecycleLocker,
|
||||||
|
SandboxRuntime,
|
||||||
|
SandboxSessionRepository,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True, slots=True)
|
@dataclass(frozen=True, slots=True)
|
||||||
|
|
@ -15,93 +21,112 @@ class CreateSandbox:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
repository: SandboxSessionRepository,
|
repository: SandboxSessionRepository,
|
||||||
|
locker: SandboxLifecycleLocker,
|
||||||
runtime: SandboxRuntime,
|
runtime: SandboxRuntime,
|
||||||
clock: Clock,
|
clock: Clock,
|
||||||
logger: Logger,
|
logger: Logger,
|
||||||
ttl: timedelta,
|
ttl: timedelta,
|
||||||
) -> None:
|
) -> None:
|
||||||
self._repository = repository
|
self._repository = repository
|
||||||
|
self._locker = locker
|
||||||
self._runtime = runtime
|
self._runtime = runtime
|
||||||
self._clock = clock
|
self._clock = clock
|
||||||
self._logger = logger
|
self._logger = logger
|
||||||
self._ttl = ttl
|
self._ttl = ttl
|
||||||
|
|
||||||
def execute(self, command: CreateSandboxCommand) -> SandboxSession:
|
def execute(self, command: CreateSandboxCommand) -> SandboxSession:
|
||||||
now = self._clock.now()
|
with self._locker.lock(command.chat_id):
|
||||||
session = self._repository.get_active_by_chat_id(command.chat_id)
|
session = self._repository.get_active_by_chat_id(command.chat_id)
|
||||||
|
now = self._clock.now()
|
||||||
|
|
||||||
if session is not None and session.expires_at > now:
|
if session is not None and session.expires_at > now:
|
||||||
|
self._logger.info(
|
||||||
|
'sandbox_reused',
|
||||||
|
attrs={
|
||||||
|
'chat_id': command.chat_id,
|
||||||
|
'session_id': session.session_id,
|
||||||
|
'container_id': session.container_id,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
return session
|
||||||
|
|
||||||
|
if session is not None:
|
||||||
|
self._logger.info(
|
||||||
|
'sandbox_replaced',
|
||||||
|
attrs={
|
||||||
|
'chat_id': command.chat_id,
|
||||||
|
'session_id': session.session_id,
|
||||||
|
'container_id': session.container_id,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
self._runtime.stop(session.container_id)
|
||||||
|
self._repository.delete(session.session_id)
|
||||||
|
|
||||||
|
created_at = self._clock.now()
|
||||||
|
expires_at = created_at + self._ttl
|
||||||
|
new_session = self._runtime.create(
|
||||||
|
session_id=_new_session_id(),
|
||||||
|
chat_id=command.chat_id,
|
||||||
|
created_at=created_at,
|
||||||
|
expires_at=expires_at,
|
||||||
|
)
|
||||||
|
self._repository.save(new_session)
|
||||||
self._logger.info(
|
self._logger.info(
|
||||||
'sandbox_reused',
|
'sandbox_created',
|
||||||
attrs={
|
attrs={
|
||||||
'chat_id': command.chat_id,
|
'chat_id': command.chat_id,
|
||||||
'session_id': session.session_id,
|
'session_id': new_session.session_id,
|
||||||
'container_id': session.container_id,
|
'container_id': new_session.container_id,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
return session
|
return new_session
|
||||||
|
|
||||||
if session is not None:
|
|
||||||
self._logger.info(
|
|
||||||
'sandbox_replaced',
|
|
||||||
attrs={
|
|
||||||
'chat_id': command.chat_id,
|
|
||||||
'session_id': session.session_id,
|
|
||||||
'container_id': session.container_id,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
self._runtime.stop(session.container_id)
|
|
||||||
self._repository.delete(session.session_id)
|
|
||||||
|
|
||||||
expires_at = now + self._ttl
|
|
||||||
new_session = self._runtime.create(
|
|
||||||
session_id=_new_session_id(),
|
|
||||||
chat_id=command.chat_id,
|
|
||||||
created_at=now,
|
|
||||||
expires_at=expires_at,
|
|
||||||
)
|
|
||||||
self._repository.save(new_session)
|
|
||||||
self._logger.info(
|
|
||||||
'sandbox_created',
|
|
||||||
attrs={
|
|
||||||
'chat_id': command.chat_id,
|
|
||||||
'session_id': new_session.session_id,
|
|
||||||
'container_id': new_session.container_id,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
return new_session
|
|
||||||
|
|
||||||
|
|
||||||
class CleanupExpiredSandboxes:
|
class CleanupExpiredSandboxes:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
repository: SandboxSessionRepository,
|
repository: SandboxSessionRepository,
|
||||||
|
locker: SandboxLifecycleLocker,
|
||||||
runtime: SandboxRuntime,
|
runtime: SandboxRuntime,
|
||||||
clock: Clock,
|
clock: Clock,
|
||||||
logger: Logger,
|
logger: Logger,
|
||||||
) -> None:
|
) -> None:
|
||||||
self._repository = repository
|
self._repository = repository
|
||||||
|
self._locker = locker
|
||||||
self._runtime = runtime
|
self._runtime = runtime
|
||||||
self._clock = clock
|
self._clock = clock
|
||||||
self._logger = logger
|
self._logger = logger
|
||||||
|
|
||||||
def execute(self) -> list[SandboxSession]:
|
def execute(self) -> list[SandboxSession]:
|
||||||
now = self._clock.now()
|
expired_sessions = self._repository.list_expired(self._clock.now())
|
||||||
expired_sessions = self._repository.list_expired(now)
|
|
||||||
cleaned_sessions: list[SandboxSession] = []
|
cleaned_sessions: list[SandboxSession] = []
|
||||||
|
|
||||||
for session in expired_sessions:
|
for session in expired_sessions:
|
||||||
self._runtime.stop(session.container_id)
|
with self._locker.lock(session.chat_id):
|
||||||
self._repository.delete(session.session_id)
|
current_session = self._repository.get_active_by_chat_id(
|
||||||
cleaned_sessions.append(session)
|
session.chat_id
|
||||||
self._logger.info(
|
)
|
||||||
'sandbox_cleaned',
|
now = self._clock.now()
|
||||||
attrs={
|
if current_session is None:
|
||||||
'chat_id': session.chat_id,
|
continue
|
||||||
'session_id': session.session_id,
|
|
||||||
'container_id': session.container_id,
|
if current_session.session_id != session.session_id:
|
||||||
},
|
continue
|
||||||
)
|
|
||||||
|
if current_session.expires_at > now:
|
||||||
|
continue
|
||||||
|
|
||||||
|
self._runtime.stop(current_session.container_id)
|
||||||
|
self._repository.delete(current_session.session_id)
|
||||||
|
cleaned_sessions.append(current_session)
|
||||||
|
self._logger.info(
|
||||||
|
'sandbox_cleaned',
|
||||||
|
attrs={
|
||||||
|
'chat_id': current_session.chat_id,
|
||||||
|
'session_id': current_session.session_id,
|
||||||
|
'container_id': current_session.container_id,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
return cleaned_sessions
|
return cleaned_sessions
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue