# File: usecase/sandbox.py (branch: master)
# Viewer metadata at extraction time: 510 lines, 17 KiB, Python.

from dataclasses import dataclass
from datetime import timedelta
from uuid import UUID, uuid4
from domain.error import SandboxConflictError
from domain.sandbox import SandboxSession
from usecase.interface import (
Clock,
Logger,
Metrics,
SandboxLifecycleLocker,
SandboxRuntime,
SandboxSessionRepository,
Span,
Tracer,
)
@dataclass(frozen=True, slots=True)
class CreateSandboxCommand:
chat_id: UUID
agent_id: str
volume_host_path: str
@dataclass(frozen=True, slots=True)
class DeleteSandboxCommand:
chat_id: UUID
@dataclass(frozen=True, slots=True)
class DeleteSandboxResult:
chat_id: UUID
result: str
session_id: UUID | None = None
container_id: str | None = None
class DeleteSandbox:
    """Remove the active sandbox session of a chat.

    The lookup, runtime teardown and repository removal all happen while the
    chat's lifecycle lock is held. Every call is traced, counted under
    ``sandbox.delete.total`` by result, and logged.
    """

    def __init__(
        self,
        repository: SandboxSessionRepository,
        locker: SandboxLifecycleLocker,
        runtime: SandboxRuntime,
        logger: Logger,
        metrics: Metrics,
        tracer: Tracer,
    ) -> None:
        self._tracer = tracer
        self._metrics = metrics
        self._logger = logger
        self._runtime = runtime
        self._locker = locker
        self._repository = repository

    def execute(self, command: DeleteSandboxCommand) -> DeleteSandboxResult:
        """Delete the chat's active sandbox, if it has one.

        Returns:
            A ``DeleteSandboxResult`` whose ``result`` is ``'not_found'`` when
            the chat had no active session, or ``'deleted'`` (carrying the
            removed session and container ids) otherwise.

        Raises:
            Exception: any failure is recorded as ``error`` on the span and the
                metric, logged, and re-raised unchanged.
        """
        chat_id = command.chat_id
        with self._tracer.start_span(
            'usecase.delete_sandbox',
            attrs={'chat.id': str(chat_id)},
        ) as span:
            # Declared before ``try`` so the failure log can include the
            # session whenever the failure happens after the lookup.
            session: SandboxSession | None = None
            try:
                with self._locker.lock(chat_id):
                    session = self._repository.get_active_by_chat_id(chat_id)
                    outcome = 'deleted' if session is not None else 'not_found'
                    if session is not None:
                        _set_session_span_attrs(span, session)
                        self._runtime.delete(session.container_id)
                        self._repository.delete(session.session_id)
                        _set_active_count(self._metrics, self._repository)
                    # Result telemetry is shared by both outcomes.
                    span.set_attribute('sandbox.result', outcome)
                    self._metrics.increment(
                        'sandbox.delete.total',
                        attrs=_result_metric_attrs(outcome),
                    )
                    if session is None:
                        self._logger.info(
                            'sandbox_delete_not_found',
                            attrs={'chat_id': str(chat_id)},
                        )
                        return DeleteSandboxResult(chat_id=chat_id, result=outcome)
                    self._logger.info(
                        'sandbox_deleted',
                        attrs=_sandbox_attrs(session),
                    )
                    return DeleteSandboxResult(
                        chat_id=chat_id,
                        result=outcome,
                        session_id=session.session_id,
                        container_id=session.container_id,
                    )
            except Exception as exc:
                span.set_attribute('sandbox.result', 'error')
                self._metrics.increment(
                    'sandbox.delete.total',
                    attrs=_result_metric_attrs('error'),
                )
                span.record_error(exc)
                self._logger.error(
                    'sandbox_delete_failed',
                    attrs=_delete_error_attrs(chat_id, session, exc),
                )
                raise
class CreateSandbox:
    """Create, reuse or replace the sandbox bound to a chat.

    Under the chat's lifecycle lock, ``execute`` looks up the active session
    and then:

    * reuses it when it has not expired and was created for the same agent
      and volume;
    * raises ``SandboxConflictError`` when it has not expired but was created
      for a different agent and/or volume;
    * otherwise (no session, or an expired one) starts a fresh sandbox,
      tearing down the expired session first if there is one.

    New sessions expire ``ttl`` after their creation time.
    """

    def __init__(
        self,
        repository: SandboxSessionRepository,
        locker: SandboxLifecycleLocker,
        runtime: SandboxRuntime,
        clock: Clock,
        logger: Logger,
        metrics: Metrics,
        tracer: Tracer,
        ttl: timedelta,
    ) -> None:
        self._repository = repository
        self._locker = locker
        self._runtime = runtime
        self._clock = clock
        self._logger = logger
        self._metrics = metrics
        self._tracer = tracer
        self._ttl = ttl

    def execute(self, command: CreateSandboxCommand) -> SandboxSession:
        """Return a live sandbox session for ``command.chat_id``.

        Raises:
            SandboxConflictError: a live session exists for a different agent
                or volume. Its telemetry is recorded where the conflict is
                detected, so it bypasses the generic error handling here.
            Exception: any other failure is recorded on the span, counted as
                ``error``, logged, and re-raised unchanged.
        """
        chat_id = command.chat_id
        with self._tracer.start_span(
            'usecase.create_sandbox',
            attrs={
                'chat.id': str(chat_id),
                'agent.id': command.agent_id,
                'volume.host_path': command.volume_host_path,
            },
        ) as span:
            try:
                with self._locker.lock(chat_id):
                    session = self._repository.get_active_by_chat_id(chat_id)
                    now = self._clock.now()
                    if session is not None and session.expires_at > now:
                        return self._reuse_or_conflict(command, session, span)
                    return self._create_or_replace(command, session, span)
            except SandboxConflictError:
                raise
            except Exception as exc:
                span.set_attribute('sandbox.result', 'error')
                self._metrics.increment(
                    'sandbox.create.total',
                    attrs=_result_metric_attrs('error'),
                )
                span.record_error(exc)
                # Logged like the delete and cleanup failure paths, so every
                # failed lifecycle operation leaves a log record, not just a span.
                self._logger.error(
                    'sandbox_create_failed',
                    attrs={
                        'chat_id': str(chat_id),
                        'agent_id': command.agent_id,
                        'error': type(exc).__name__,
                    },
                )
                raise

    def _reuse_or_conflict(
        self,
        command: CreateSandboxCommand,
        session: SandboxSession,
        span: Span,
    ) -> SandboxSession:
        """Return the live ``session`` if it matches ``command``.

        Raises:
            SandboxConflictError: the session was created for a different
                agent and/or volume than ``command`` requests.
        """
        _set_session_span_attrs(span, session)
        if not _session_matches_command(session, command):
            reason = _conflict_reason(session, command)
            span.set_attribute('sandbox.result', 'conflict')
            span.set_attribute('sandbox.conflict.reason', reason)
            self._metrics.increment(
                'sandbox.create.total',
                attrs=_conflict_metric_attrs(reason),
            )
            self._logger.warning(
                'sandbox_conflict',
                attrs=_conflict_attrs(session, command, reason),
            )
            raise SandboxConflictError(str(command.chat_id))
        span.set_attribute('sandbox.result', 'reused')
        self._metrics.increment(
            'sandbox.create.total',
            attrs=_result_metric_attrs('reused'),
        )
        self._logger.info(
            'sandbox_reused',
            attrs=_sandbox_attrs(session),
        )
        return session

    def _create_or_replace(
        self,
        command: CreateSandboxCommand,
        session: SandboxSession | None,
        span: Span,
    ) -> SandboxSession:
        """Start a new sandbox for ``command``, persist it and return it.

        ``session`` is ``None`` (result ``created``) or the chat's expired
        session (result ``replaced``), which is torn down first.
        """
        if session is None:
            result = 'created'
            new_session_id = _new_session_id()
        else:
            result = 'replaced'
            new_session_id = self._replace_expired_session(session, span)
        created_at = self._clock.now()
        expires_at = created_at + self._ttl
        # Set before the runtime call so a failed create is still attributable
        # to the session id it was started with.
        span.set_attribute('session.id', str(new_session_id))
        new_session = self._runtime.create(
            session_id=new_session_id,
            chat_id=command.chat_id,
            agent_id=command.agent_id,
            volume_host_path=command.volume_host_path,
            created_at=created_at,
            expires_at=expires_at,
        )
        if result == 'replaced':
            _set_new_session_span_attrs(span, new_session)
        self._save_created_session(new_session)
        _set_active_count(self._metrics, self._repository)
        # Applied for both outcomes so created sessions also carry container,
        # agent, volume and endpoint attributes; this also refreshes
        # ``session.id`` from the session the runtime actually returned.
        _set_session_span_attrs(span, new_session)
        span.set_attribute('sandbox.result', result)
        self._metrics.increment(
            'sandbox.create.total',
            attrs=_result_metric_attrs(result),
        )
        self._logger.info(
            'sandbox_created',
            attrs=_sandbox_attrs(new_session),
        )
        return new_session

    def _replace_expired_session(
        self,
        session: SandboxSession,
        span: Span,
    ) -> UUID:
        """Tear down the expired ``session`` and return an id for its successor."""
        new_session_id = _new_session_id()
        _set_previous_session_span_attrs(span, session)
        span.set_attribute('sandbox.new_session.id', str(new_session_id))
        self._runtime.stop(session.container_id)
        self._repository.delete(session.session_id)
        # Logged only once the old session is actually gone, so the event is
        # not emitted for a replacement that failed part-way through.
        self._logger.info(
            'sandbox_replaced',
            attrs=_sandbox_attrs(session),
        )
        _set_active_count(self._metrics, self._repository)
        return new_session_id

    def _save_created_session(self, session: SandboxSession) -> None:
        """Persist ``session``; on failure stop its container and re-raise.

        The original save error is always the exception that propagates.
        """
        try:
            self._repository.save(session)
        except Exception as exc:
            self._compensate_save_failure(session, exc)
            raise

    def _compensate_save_failure(
        self,
        session: SandboxSession,
        error: Exception,
    ) -> None:
        """Best-effort rollback of a container whose session could not be saved.

        Failures of the rollback itself are logged rather than propagated, so
        they cannot replace the save ``error`` that the caller re-raises.
        """
        try:
            self._runtime.stop(session.container_id)
        except Exception as stop_error:
            # The container may now be running untracked; log enough to find it.
            attrs = _sandbox_attrs(session)
            attrs['error'] = type(error).__name__
            attrs['stop_error'] = type(stop_error).__name__
            self._logger.error(
                'sandbox_compensation_failed',
                attrs=attrs,
            )
        try:
            # The repository has just failed, so this refresh may fail too.
            _set_active_count(self._metrics, self._repository)
        except Exception as count_error:
            self._logger.warning(
                'sandbox_active_count_refresh_failed',
                attrs={'error': type(count_error).__name__},
            )
class CleanupExpiredSandboxes:
    """Stop and forget every sandbox session whose expiry time has passed.

    Sessions are processed one at a time, each under its chat's lifecycle
    lock. A failure on one session is recorded and does not stop the sweep.
    """

    def __init__(
        self,
        repository: SandboxSessionRepository,
        locker: SandboxLifecycleLocker,
        runtime: SandboxRuntime,
        clock: Clock,
        logger: Logger,
        metrics: Metrics,
        tracer: Tracer,
    ) -> None:
        self._repository = repository
        self._locker = locker
        self._runtime = runtime
        self._clock = clock
        self._logger = logger
        self._metrics = metrics
        self._tracer = tracer

    def execute(self) -> list[SandboxSession]:
        """Run one cleanup sweep.

        Returns:
            The sessions cleaned by this sweep. Sessions that were skipped
            (gone, replaced or no longer expired once the lock was taken) and
            sessions whose cleanup failed are not included.

        Raises:
            Exception: a failure while listing expired sessions is recorded,
                logged and re-raised. Per-session failures are recorded and
                counted in ``sandbox.error_count`` instead of raised.
        """
        cleaned_sessions: list[SandboxSession] = []
        error_count = 0
        with self._tracer.start_span(
            'usecase.cleanup_expired_sandboxes',
        ) as span:
            try:
                expired_sessions = self._repository.list_expired(self._clock.now())
            except Exception as exc:
                self._record_list_failure(span, exc)
                raise
            span.set_attribute('sandbox.expired_count', len(expired_sessions))
            for session in expired_sessions:
                with self._tracer.start_span(
                    'usecase.cleanup_expired_sandbox',
                    attrs=_sandbox_span_attrs(session),
                ) as cleanup_span:
                    try:
                        cleaned_session = self._cleanup_session(session)
                    except Exception as exc:
                        # Keep sweeping: one broken session must not block the rest.
                        error_count += 1
                        self._record_session_failure(cleanup_span, session, exc)
                        continue
                    if cleaned_session is None:
                        cleanup_span.set_attribute('sandbox.result', 'skipped')
                        continue
                    cleaned_sessions.append(cleaned_session)
                    self._record_session_cleaned(cleanup_span, cleaned_session)
            span.set_attribute('sandbox.cleaned_count', len(cleaned_sessions))
            span.set_attribute('sandbox.error_count', error_count)
            span.set_attribute(
                'sandbox.result',
                'completed' if error_count == 0 else 'completed_with_errors',
            )
            return cleaned_sessions

    def _record_list_failure(self, span: Span, error: Exception) -> None:
        """Record a failure of ``repository.list_expired`` on span, metric and log."""
        span.set_attribute('sandbox.result', 'error')
        self._metrics.increment(
            'sandbox.cleanup.error.total',
            attrs=_cleanup_error_metric_attrs(
                type(error).__name__,
                'list_expired',
            ),
        )
        span.record_error(error)
        # Logged like the per-session failures below; previously this path
        # left no log record at all.
        self._logger.error(
            'sandbox_cleanup_failed',
            attrs={'error': type(error).__name__, 'reason': 'list_expired'},
        )

    def _record_session_failure(
        self,
        span: Span,
        session: SandboxSession,
        error: Exception,
    ) -> None:
        """Record a failed cleanup of one ``session`` on span, metric and log."""
        span.set_attribute('sandbox.result', 'error')
        span.record_error(error)
        # Same label set as the list_expired failure, so every series of
        # ``sandbox.cleanup.error.total`` carries both ``error.type`` and ``reason``.
        self._metrics.increment(
            'sandbox.cleanup.error.total',
            attrs=_cleanup_error_metric_attrs(
                type(error).__name__,
                'cleanup_session',
            ),
        )
        attrs = _sandbox_attrs(session)
        attrs['error'] = type(error).__name__
        self._logger.error(
            'sandbox_clean_failed',
            attrs=attrs,
        )

    def _record_session_cleaned(self, span: Span, session: SandboxSession) -> None:
        """Record a successful cleanup of ``session`` on span, metric and log."""
        span.set_attribute('sandbox.result', 'cleaned')
        self._metrics.increment(
            'sandbox.cleanup.total',
            attrs=_result_metric_attrs('cleaned'),
        )
        self._logger.info(
            'sandbox_cleaned',
            attrs=_sandbox_attrs(session),
        )

    def _cleanup_session(self, session: SandboxSession) -> SandboxSession | None:
        """Stop and delete ``session`` if it is still the chat's expired session.

        The state is re-read under the chat's lock because the listed session
        may have been deleted, replaced or renewed since ``list_expired``.
        Returns the deleted session, or ``None`` when there was nothing to do.
        """
        with self._locker.lock(session.chat_id):
            current_session = self._repository.get_active_by_chat_id(session.chat_id)
            now = self._clock.now()
            if current_session is None:
                return None
            if current_session.session_id != session.session_id:
                return None
            if current_session.expires_at > now:
                return None
            self._runtime.stop(current_session.container_id)
            self._repository.delete(current_session.session_id)
            _set_active_count(self._metrics, self._repository)
            return current_session
def _new_session_id() -> UUID:
    """Return a fresh random (version 4) UUID to identify a sandbox session."""
    session_id = uuid4()
    return session_id
def _session_matches_command(
    session: SandboxSession,
    command: CreateSandboxCommand,
) -> bool:
    """Return True when ``session`` was created for the agent and volume ``command`` asks for."""
    existing = (session.agent_id, session.volume_host_path)
    requested = (command.agent_id, command.volume_host_path)
    return existing == requested
def _conflict_reason(
    session: SandboxSession,
    command: CreateSandboxCommand,
) -> str:
    """Name the fields on which ``session`` and ``command`` disagree.

    Returns a comma-separated list drawn from ``agent_id`` and
    ``volume_host_path`` (in that order), or ``''`` when they agree.
    """
    comparisons = (
        ('agent_id', session.agent_id, command.agent_id),
        ('volume_host_path', session.volume_host_path, command.volume_host_path),
    )
    return ','.join(
        field_name
        for field_name, existing, requested in comparisons
        if existing != requested
    )
def _set_session_span_attrs(span: Span, session: SandboxSession) -> None:
    """Copy the identifying fields of ``session`` onto ``span``.

    Endpoint ip/port attributes are added only when the session has an endpoint.
    """
    identity = {
        'session.id': str(session.session_id),
        'container.id': session.container_id,
        'sandbox.session_agent.id': session.agent_id,
        'sandbox.session_volume.host_path': session.volume_host_path,
    }
    for key, value in identity.items():
        span.set_attribute(key, value)
    endpoint = session.endpoint
    if endpoint is None:
        return
    span.set_attribute('sandbox.endpoint.ip', endpoint.ip)
    span.set_attribute('sandbox.endpoint.port', endpoint.port)
def _set_previous_session_span_attrs(span: Span, session: SandboxSession) -> None:
    """Tag ``span`` with the identity of the session being replaced."""
    attributes = (
        ('sandbox.previous_session.id', str(session.session_id)),
        ('sandbox.previous_container.id', session.container_id),
        ('sandbox.previous_agent.id', session.agent_id),
        ('sandbox.previous_volume.host_path', session.volume_host_path),
    )
    for key, value in attributes:
        span.set_attribute(key, value)
def _set_new_session_span_attrs(span: Span, session: SandboxSession) -> None:
    """Tag ``span`` with the container, agent and volume of a replacement session."""
    attributes = (
        ('sandbox.new_container.id', session.container_id),
        ('sandbox.new_agent.id', session.agent_id),
        ('sandbox.new_volume.host_path', session.volume_host_path),
    )
    for key, value in attributes:
        span.set_attribute(key, value)
def _sandbox_attrs(session: SandboxSession) -> dict[str, str]:
    """Build the log attributes identifying ``session`` (a fresh dict per call)."""
    return dict(
        chat_id=str(session.chat_id),
        session_id=str(session.session_id),
        container_id=session.container_id,
    )
def _sandbox_span_attrs(session: SandboxSession) -> dict[str, str]:
    """Build the span attributes identifying ``session`` (dotted key names)."""
    attrs: dict[str, str] = {}
    attrs['chat.id'] = str(session.chat_id)
    attrs['session.id'] = str(session.session_id)
    attrs['container.id'] = session.container_id
    return attrs
def _result_metric_attrs(result: str) -> dict[str, str]:
    """Build metric attributes carrying only the outcome label ``result``."""
    return dict(result=result)
def _conflict_metric_attrs(reason: str) -> dict[str, str]:
    """Build metric attributes for a conflict outcome, labelled with ``reason``."""
    return dict(result='conflict', reason=reason)
def _conflict_attrs(
    session: SandboxSession,
    command: CreateSandboxCommand,
    reason: str,
) -> dict[str, str]:
    """Build log attributes contrasting the live ``session`` with ``command``."""
    return dict(
        chat_id=str(command.chat_id),
        session_id=str(session.session_id),
        container_id=session.container_id,
        requested_agent_id=command.agent_id,
        session_agent_id=session.agent_id,
        requested_volume_host_path=command.volume_host_path,
        session_volume_host_path=session.volume_host_path,
        reason=reason,
    )
def _delete_error_attrs(
    chat_id: UUID,
    session: SandboxSession | None,
    error: Exception,
) -> dict[str, str]:
    """Build log attributes for a failed delete.

    Always carries ``chat_id`` and the error's class name; when ``session`` is
    known, its identifying attributes are merged in as well.
    """
    base: dict[str, str] = {'chat_id': str(chat_id), 'error': type(error).__name__}
    if session is None:
        return base
    return base | _sandbox_attrs(session)
def _error_metric_attrs(error_type: str) -> dict[str, str]:
    """Build metric attributes labelled only with the error class name."""
    attrs: dict[str, str] = {}
    attrs['error.type'] = error_type
    return attrs
def _cleanup_error_metric_attrs(
    error_type: str,
    reason: str,
) -> dict[str, str]:
    """Build cleanup-error metric attributes: error class name plus failing step."""
    pairs = (('error.type', error_type), ('reason', reason))
    return dict(pairs)
def _set_active_count(
    metrics: Metrics,
    repository: SandboxSessionRepository,
) -> None:
    """Publish the repository's current active-session count as a gauge."""
    active_sessions = repository.count_active()
    metrics.set('sandbox.active.count', active_sessions)