"""Sandbox lifecycle use cases: create, delete and clean up sandbox sessions."""

from dataclasses import dataclass
from datetime import timedelta
from uuid import UUID, uuid4

from domain.error import SandboxConflictError
from domain.sandbox import SandboxSession
from usecase.interface import (
    Clock,
    Logger,
    Metrics,
    SandboxLifecycleLocker,
    SandboxRuntime,
    SandboxSessionRepository,
    Span,
    Tracer,
)


@dataclass(frozen=True, slots=True)
class CreateSandboxCommand:
    """Input for ``CreateSandbox.execute``."""

    chat_id: UUID
    agent_id: str
    volume_host_path: str


@dataclass(frozen=True, slots=True)
class DeleteSandboxCommand:
    """Input for ``DeleteSandbox.execute``."""

    chat_id: UUID


@dataclass(frozen=True, slots=True)
class DeleteSandboxResult:
    """Outcome of ``DeleteSandbox.execute``.

    ``result`` is ``'not_found'`` or ``'deleted'``; ``session_id`` and
    ``container_id`` are only filled in for ``'deleted'``.
    """

    chat_id: UUID
    result: str
    session_id: UUID | None = None
    container_id: str | None = None


class DeleteSandbox:
    """Delete the active sandbox session of a chat, if there is one."""

    def __init__(
        self,
        repository: SandboxSessionRepository,
        locker: SandboxLifecycleLocker,
        runtime: SandboxRuntime,
        logger: Logger,
        metrics: Metrics,
        tracer: Tracer,
    ) -> None:
        self._repository = repository
        self._locker = locker
        self._runtime = runtime
        self._logger = logger
        self._metrics = metrics
        self._tracer = tracer

    def execute(self, command: DeleteSandboxCommand) -> DeleteSandboxResult:
        """Delete the chat's active session while holding the chat's lock.

        Returns:
            A result with ``result='not_found'`` when the repository has no
            active session for the chat, otherwise ``result='deleted'`` with
            the ids of the removed session and container.

        Raises:
            Exception: whatever the collaborators raise; it is recorded on
                the span, counted, logged and then re-raised unchanged.
        """
        chat_id = command.chat_id
        with self._tracer.start_span(
            'usecase.delete_sandbox',
            attrs={'chat.id': str(chat_id)},
        ) as span:
            # Kept outside the ``try`` so the error handler can report which
            # session (if any) had been loaded when the failure happened.
            session: SandboxSession | None = None
            try:
                with self._locker.lock(chat_id):
                    session = self._repository.get_active_by_chat_id(chat_id)
                    if session is None:
                        span.set_attribute('sandbox.result', 'not_found')
                        self._metrics.increment(
                            'sandbox.delete.total',
                            attrs=_result_metric_attrs('not_found'),
                        )
                        self._logger.info(
                            'sandbox_delete_not_found',
                            attrs={'chat_id': str(chat_id)},
                        )
                        return DeleteSandboxResult(
                            chat_id=chat_id,
                            result='not_found',
                        )

                    _set_session_span_attrs(span, session)
                    # Runtime first: the stored session is only removed once
                    # the runtime call has returned without raising.
                    self._runtime.delete(session.container_id)
                    self._repository.delete(session.session_id)
                    _set_active_count(self._metrics, self._repository)
                    span.set_attribute('sandbox.result', 'deleted')
                    self._metrics.increment(
                        'sandbox.delete.total',
                        attrs=_result_metric_attrs('deleted'),
                    )
                    self._logger.info(
                        'sandbox_deleted',
                        attrs=_sandbox_attrs(session),
                    )
                    return DeleteSandboxResult(
                        chat_id=chat_id,
                        result='deleted',
                        session_id=session.session_id,
                        container_id=session.container_id,
                    )
            except Exception as exc:
                span.set_attribute('sandbox.result', 'error')
                self._metrics.increment(
                    'sandbox.delete.total',
                    attrs=_result_metric_attrs('error'),
                )
                span.record_error(exc)
                self._logger.error(
                    'sandbox_delete_failed',
                    attrs=_delete_error_attrs(chat_id, session, exc),
                )
                raise


class CreateSandbox:
    """Create a sandbox for a chat, reusing or replacing an existing one."""

    def __init__(
        self,
        repository: SandboxSessionRepository,
        locker: SandboxLifecycleLocker,
        runtime: SandboxRuntime,
        clock: Clock,
        logger: Logger,
        metrics: Metrics,
        tracer: Tracer,
        ttl: timedelta,
    ) -> None:
        self._repository = repository
        self._locker = locker
        self._runtime = runtime
        self._clock = clock
        self._logger = logger
        self._metrics = metrics
        self._tracer = tracer
        self._ttl = ttl

    def execute(self, command: CreateSandboxCommand) -> SandboxSession:
        """Return a usable session for the chat, creating one when needed.

        While holding the chat's lock:

        * an active session with ``expires_at > now`` is reused when its
          agent id and volume path match the command, otherwise the request
          is rejected as a conflict;
        * an active session that is no longer in the future is stopped and
          replaced;
        * with no active session a new one is created.

        Raises:
            SandboxConflictError: the unexpired session does not match the
                command. Already counted and logged by
                ``_reuse_or_conflict``, so it is passed through untouched.
            Exception: any other failure; recorded on the span, counted,
                logged and re-raised unchanged.
        """
        chat_id = command.chat_id
        with self._tracer.start_span(
            'usecase.create_sandbox',
            attrs={
                'chat.id': str(chat_id),
                'agent.id': command.agent_id,
                'volume.host_path': command.volume_host_path,
            },
        ) as span:
            try:
                with self._locker.lock(chat_id):
                    session = self._repository.get_active_by_chat_id(chat_id)
                    now = self._clock.now()
                    if session is not None and session.expires_at > now:
                        return self._reuse_or_conflict(command, session, span)
                    return self._create_or_replace(command, session, span)
            except SandboxConflictError:
                # Conflict telemetry was emitted where it was detected.
                raise
            except Exception as exc:
                span.set_attribute('sandbox.result', 'error')
                self._metrics.increment(
                    'sandbox.create.total',
                    attrs=_result_metric_attrs('error'),
                )
                span.record_error(exc)
                # Mirrors DeleteSandbox: a failed create leaves a log event,
                # not only a span and a counter.
                self._logger.error(
                    'sandbox_create_failed',
                    attrs=_create_error_attrs(command, exc),
                )
                raise

    def _reuse_or_conflict(
        self,
        command: CreateSandboxCommand,
        session: SandboxSession,
        span: Span,
    ) -> SandboxSession:
        """Return ``session`` if it matches ``command``; otherwise raise.

        Raises:
            SandboxConflictError: agent id and/or volume host path differ.
        """
        _set_session_span_attrs(span, session)
        if not _session_matches_command(session, command):
            reason = _conflict_reason(session, command)
            span.set_attribute('sandbox.result', 'conflict')
            span.set_attribute('sandbox.conflict.reason', reason)
            self._metrics.increment(
                'sandbox.create.total',
                attrs=_conflict_metric_attrs(reason),
            )
            self._logger.warning(
                'sandbox_conflict',
                attrs=_conflict_attrs(session, command, reason),
            )
            raise SandboxConflictError(str(command.chat_id))

        span.set_attribute('sandbox.result', 'reused')
        self._metrics.increment(
            'sandbox.create.total',
            attrs=_result_metric_attrs('reused'),
        )
        self._logger.info(
            'sandbox_reused',
            attrs=_sandbox_attrs(session),
        )
        return session

    def _create_or_replace(
        self,
        command: CreateSandboxCommand,
        session: SandboxSession | None,
        span: Span,
    ) -> SandboxSession:
        """Create a new session, first removing ``session`` when given.

        ``session`` is the stale session to replace, or ``None`` when the
        chat has no active session. The outcome is reported as
        ``'replaced'`` or ``'created'`` respectively.
        """
        if session is None:
            result = 'created'
            new_session_id = _new_session_id()
        else:
            result = 'replaced'
            new_session_id = self._replace_expired_session(session, span)

        created_at = self._clock.now()
        expires_at = created_at + self._ttl
        # Set before calling the runtime so a failed create is still
        # attributable to the id that was being provisioned.
        span.set_attribute('session.id', str(new_session_id))
        new_session = self._runtime.create(
            session_id=new_session_id,
            chat_id=command.chat_id,
            agent_id=command.agent_id,
            volume_host_path=command.volume_host_path,
            created_at=created_at,
            expires_at=expires_at,
        )
        if result == 'replaced':
            _set_new_session_span_attrs(span, new_session)

        self._save_created_session(new_session)
        _set_active_count(self._metrics, self._repository)
        # Also (re)writes 'session.id' from the session the runtime returned.
        _set_session_span_attrs(span, new_session)
        span.set_attribute('sandbox.result', result)
        self._metrics.increment(
            'sandbox.create.total',
            attrs=_result_metric_attrs(result),
        )
        self._logger.info(
            'sandbox_created',
            attrs=_sandbox_attrs(new_session),
        )
        return new_session

    def _replace_expired_session(
        self,
        session: SandboxSession,
        span: Span,
    ) -> UUID:
        """Stop and delete ``session``; return the id for its replacement."""
        new_session_id = _new_session_id()
        _set_previous_session_span_attrs(span, session)
        span.set_attribute('sandbox.new_session.id', str(new_session_id))
        self._logger.info(
            'sandbox_replaced',
            attrs=_sandbox_attrs(session),
        )
        self._runtime.stop(session.container_id)
        self._repository.delete(session.session_id)
        _set_active_count(self._metrics, self._repository)
        return new_session_id

    def _save_created_session(self, session: SandboxSession) -> None:
        """Persist ``session``; on failure stop its container and re-raise."""
        try:
            self._repository.save(session)
        except Exception as exc:
            self._compensate_save_failure(session, exc)
            raise

    def _compensate_save_failure(
        self,
        session: SandboxSession,
        error: Exception,
    ) -> None:
        """Stop the container of a session that could not be saved.

        Returns normally when the container was stopped, so the caller
        re-raises the original save error.

        Raises:
            Exception: ``error`` itself, chained ``from`` the stop failure,
                when stopping the container fails as well.
        """
        try:
            self._runtime.stop(session.container_id)
        except Exception as stop_error:
            # The container may still be running with no stored session;
            # make that visible instead of leaving it only in __cause__.
            attrs = _sandbox_attrs(session)
            attrs['error'] = type(stop_error).__name__
            self._logger.error('sandbox_compensation_failed', attrs=attrs)
            self._refresh_active_count_after_failure(session)
            raise error from stop_error
        self._refresh_active_count_after_failure(session)

    def _refresh_active_count_after_failure(
        self,
        session: SandboxSession,
    ) -> None:
        """Refresh the active-count gauge without masking the save error.

        The refresh queries the repository that has just failed to save, so
        it may fail too. The save error is the one the caller must see, so a
        failure here is logged and dropped.
        """
        try:
            _set_active_count(self._metrics, self._repository)
        except Exception as count_error:
            attrs = _sandbox_attrs(session)
            attrs['error'] = type(count_error).__name__
            self._logger.warning(
                'sandbox_active_count_refresh_failed',
                attrs=attrs,
            )


class CleanupExpiredSandboxes:
    """Stop and delete every sandbox session the repository lists as expired."""

    def __init__(
        self,
        repository: SandboxSessionRepository,
        locker: SandboxLifecycleLocker,
        runtime: SandboxRuntime,
        clock: Clock,
        logger: Logger,
        metrics: Metrics,
        tracer: Tracer,
    ) -> None:
        self._repository = repository
        self._locker = locker
        self._runtime = runtime
        self._clock = clock
        self._logger = logger
        self._metrics = metrics
        self._tracer = tracer

    def execute(self) -> list[SandboxSession]:
        """Clean up expired sessions and return the ones actually removed.

        A failure on one session is recorded and counted, then the loop moves
        on to the next session. Only a failure to list the expired sessions
        aborts the run.

        Raises:
            Exception: whatever ``repository.list_expired`` raises.
        """
        cleaned_sessions: list[SandboxSession] = []
        error_count = 0
        with self._tracer.start_span(
            'usecase.cleanup_expired_sandboxes',
        ) as span:
            try:
                expired_sessions = self._repository.list_expired(
                    self._clock.now(),
                )
            except Exception as exc:
                span.set_attribute('sandbox.result', 'error')
                self._metrics.increment(
                    'sandbox.cleanup.error.total',
                    attrs=_cleanup_error_metric_attrs(
                        type(exc).__name__,
                        'list_expired',
                    ),
                )
                span.record_error(exc)
                # Same convention as the per-session path: a failure also
                # leaves a log event.
                self._logger.error(
                    'sandbox_cleanup_list_failed',
                    attrs={'error': type(exc).__name__},
                )
                raise

            span.set_attribute('sandbox.expired_count', len(expired_sessions))
            for session in expired_sessions:
                with self._tracer.start_span(
                    'usecase.cleanup_expired_sandbox',
                    attrs=_sandbox_span_attrs(session),
                ) as cleanup_span:
                    try:
                        cleaned_session = self._cleanup_session(session)
                    except Exception as exc:
                        error_count += 1
                        cleanup_span.set_attribute('sandbox.result', 'error')
                        cleanup_span.record_error(exc)
                        self._metrics.increment(
                            'sandbox.cleanup.error.total',
                            attrs=_error_metric_attrs(type(exc).__name__),
                        )
                        attrs = _sandbox_attrs(session)
                        attrs['error'] = type(exc).__name__
                        self._logger.error(
                            'sandbox_clean_failed',
                            attrs=attrs,
                        )
                        continue

                    if cleaned_session is None:
                        cleanup_span.set_attribute('sandbox.result', 'skipped')
                        continue

                    cleanup_span.set_attribute('sandbox.result', 'cleaned')
                    cleaned_sessions.append(cleaned_session)
                    self._metrics.increment(
                        'sandbox.cleanup.total',
                        attrs=_result_metric_attrs('cleaned'),
                    )
                    self._logger.info(
                        'sandbox_cleaned',
                        attrs=_sandbox_attrs(cleaned_session),
                    )

            span.set_attribute('sandbox.cleaned_count', len(cleaned_sessions))
            span.set_attribute('sandbox.error_count', error_count)
            span.set_attribute(
                'sandbox.result',
                'completed' if error_count == 0 else 'completed_with_errors',
            )
            return cleaned_sessions

    def _cleanup_session(self, session: SandboxSession) -> SandboxSession | None:
        """Remove ``session`` if it is still the chat's expired active session.

        The active session is re-read under the chat's lock because the
        listing was taken without it. Returns ``None`` (nothing touched) when
        the chat has no active session, when the active session has a
        different id, or when its ``expires_at`` is still in the future.
        Otherwise the container is stopped, the session deleted, and the
        re-read session returned.
        """
        with self._locker.lock(session.chat_id):
            current_session = self._repository.get_active_by_chat_id(
                session.chat_id,
            )
            now = self._clock.now()
            if current_session is None:
                return None
            if current_session.session_id != session.session_id:
                return None
            if current_session.expires_at > now:
                return None

            self._runtime.stop(current_session.container_id)
            self._repository.delete(current_session.session_id)
            _set_active_count(self._metrics, self._repository)
            return current_session


def _new_session_id() -> UUID:
    """Return a fresh random session id."""
    return uuid4()


def _session_matches_command(
    session: SandboxSession,
    command: CreateSandboxCommand,
) -> bool:
    """Return True when agent id and volume host path both match."""
    return (
        session.agent_id == command.agent_id
        and session.volume_host_path == command.volume_host_path
    )


def _conflict_reason(
    session: SandboxSession,
    command: CreateSandboxCommand,
) -> str:
    """Return the mismatching field names, comma-joined.

    The result is ``'agent_id'``, ``'volume_host_path'``, both joined as
    ``'agent_id,volume_host_path'``, or ``''`` when nothing differs.
    """
    reasons: list[str] = []
    if session.agent_id != command.agent_id:
        reasons.append('agent_id')
    if session.volume_host_path != command.volume_host_path:
        reasons.append('volume_host_path')
    return ','.join(reasons)


def _set_session_span_attrs(span: Span, session: SandboxSession) -> None:
    """Describe ``session`` on ``span`` as the span's current session."""
    span.set_attribute('session.id', str(session.session_id))
    span.set_attribute('container.id', session.container_id)
    span.set_attribute('sandbox.session_agent.id', session.agent_id)
    span.set_attribute(
        'sandbox.session_volume.host_path',
        session.volume_host_path,
    )
    if session.endpoint is not None:
        span.set_attribute('sandbox.endpoint.ip', session.endpoint.ip)
        span.set_attribute('sandbox.endpoint.port', session.endpoint.port)


def _set_previous_session_span_attrs(
    span: Span,
    session: SandboxSession,
) -> None:
    """Describe ``session`` on ``span`` as the session being replaced."""
    span.set_attribute('sandbox.previous_session.id', str(session.session_id))
    span.set_attribute('sandbox.previous_container.id', session.container_id)
    span.set_attribute('sandbox.previous_agent.id', session.agent_id)
    span.set_attribute(
        'sandbox.previous_volume.host_path',
        session.volume_host_path,
    )


def _set_new_session_span_attrs(span: Span, session: SandboxSession) -> None:
    """Describe ``session`` on ``span`` as the replacement session."""
    span.set_attribute('sandbox.new_container.id', session.container_id)
    span.set_attribute('sandbox.new_agent.id', session.agent_id)
    span.set_attribute('sandbox.new_volume.host_path', session.volume_host_path)


def _sandbox_attrs(session: SandboxSession) -> dict[str, str]:
    """Return a new dict of log attributes identifying ``session``."""
    return {
        'chat_id': str(session.chat_id),
        'session_id': str(session.session_id),
        'container_id': session.container_id,
    }


def _sandbox_span_attrs(session: SandboxSession) -> dict[str, str]:
    """Return span attributes identifying ``session`` (dotted key names)."""
    return {
        'chat.id': str(session.chat_id),
        'session.id': str(session.session_id),
        'container.id': session.container_id,
    }


def _result_metric_attrs(result: str) -> dict[str, str]:
    """Return metric attributes carrying only the outcome label."""
    return {'result': result}


def _conflict_metric_attrs(reason: str) -> dict[str, str]:
    """Return metric attributes for a conflict outcome with its reason."""
    return {
        'result': 'conflict',
        'reason': reason,
    }


def _conflict_attrs(
    session: SandboxSession,
    command: CreateSandboxCommand,
    reason: str,
) -> dict[str, str]:
    """Return log attributes contrasting the request with the session."""
    return {
        'chat_id': str(command.chat_id),
        'session_id': str(session.session_id),
        'container_id': session.container_id,
        'requested_agent_id': command.agent_id,
        'session_agent_id': session.agent_id,
        'requested_volume_host_path': command.volume_host_path,
        'session_volume_host_path': session.volume_host_path,
        'reason': reason,
    }


def _create_error_attrs(
    command: CreateSandboxCommand,
    error: Exception,
) -> dict[str, str]:
    """Return log attributes for a failed create request."""
    return {
        'chat_id': str(command.chat_id),
        'agent_id': command.agent_id,
        'error': type(error).__name__,
    }


def _delete_error_attrs(
    chat_id: UUID,
    session: SandboxSession | None,
    error: Exception,
) -> dict[str, str]:
    """Return log attributes for a failed delete.

    Session and container ids are included only when a session had already
    been loaded before the failure.
    """
    attrs = {'chat_id': str(chat_id), 'error': type(error).__name__}
    if session is not None:
        attrs.update(_sandbox_attrs(session))
    return attrs


def _error_metric_attrs(error_type: str) -> dict[str, str]:
    """Return metric attributes carrying only the error type name."""
    return {'error.type': error_type}


def _cleanup_error_metric_attrs(
    error_type: str,
    reason: str,
) -> dict[str, str]:
    """Return metric attributes with the error type and the failed step."""
    return {
        'error.type': error_type,
        'reason': reason,
    }


def _set_active_count(
    metrics: Metrics,
    repository: SandboxSessionRepository,
) -> None:
    """Set the ``sandbox.active.count`` metric from the repository's count."""
    metrics.set('sandbox.active.count', repository.count_active())