"""Tests for ``DockerSandboxRuntime`` driven by in-memory Docker fakes.

The fakes stand in for the parts of the docker-py client the runtime uses in
these tests (``containers.run`` / ``get`` / ``list`` and container ``stop`` /
``reload`` / ``remove``) and record every call so tests can assert on the
arguments. ``RecordingMetrics`` and ``RecordingTracer`` capture observability
output so tests can check metric names/attributes, span attributes and
recorded errors.
"""

from dataclasses import replace
from datetime import UTC, datetime, timedelta
from pathlib import Path
from types import TracebackType
from typing import Any, TypedDict
from uuid import UUID

import pytest
from docker import DockerClient
from docker.errors import DockerException, NotFound
from docker.types import Mount

from adapter.config.model import SandboxConfig
from adapter.docker.runtime import DockerSandboxRuntime
from adapter.observability.noop import NoopMetrics, NoopTracer
from domain.error import SandboxError, SandboxStartError
from domain.sandbox import SandboxEndpoint, SandboxSession, SandboxStatus
from usecase.interface import Attrs, AttrValue

CHAT_ID = UUID('123e4567-e89b-12d3-a456-426614174000')
# The same UUID as CHAT_ID, spelled upper-case and without hyphens.
NON_CANONICAL_CHAT_ID = '123E4567E89B12D3A456426614174000'
SESSION_ID = UUID('00000000-0000-0000-0000-000000000010')
AGENT_ID = 'agent-alpha'


def _network_attrs(network_name: str = 'sandbox', ip: str = '172.20.0.8') -> dict[str, object]:
    """Build a container ``attrs`` payload attaching the container to one network.

    The nesting mirrors ``NetworkSettings.Networks.<name>.IPAddress``; the
    endpoint tests below rely on the runtime reading the IP for the
    configured network from this location.
    """
    return {
        'NetworkSettings': {
            'Networks': {
                network_name: {
                    'IPAddress': ip,
                }
            }
        }
    }


class FakeContainer:
    """Minimal container stand-in that records ``stop``/``reload``/``remove`` calls."""

    def __init__(
        self,
        container_id: str,
        *,
        network_name: str = 'sandbox',
        ip: str = '172.20.0.8',
    ) -> None:
        self.id = container_id
        self.stop_calls = 0
        self.remove_calls: list[dict[str, bool]] = []
        self.reload_calls = 0
        self.attrs = _network_attrs(network_name, ip)
        self.labels: dict[str, str] = {}

    def stop(self) -> None:
        self.stop_calls += 1

    def reload(self) -> None:
        self.reload_calls += 1

    def remove(self, *, force: bool) -> None:
        self.remove_calls.append({'force': force})


class FakeListedContainer(FakeContainer):
    """Fake container that also carries labels and a ``Created`` timestamp in ``attrs``."""

    def __init__(
        self,
        container_id: str,
        *,
        labels: dict[str, str],
        created_at: str,
        network_name: str = 'sandbox',
        ip: str = '172.20.0.8',
    ) -> None:
        super().__init__(container_id, network_name=network_name, ip=ip)
        self.labels = labels
        self.attrs['Created'] = created_at


class FailingStopContainer(FakeListedContainer):
    """Fake listed container whose ``stop`` counts the call and then raises ``error``."""

    def __init__(
        self,
        container_id: str,
        *,
        labels: dict[str, str],
        created_at: str,
        error: Exception,
    ) -> None:
        super().__init__(
            container_id,
            labels=labels,
            created_at=created_at,
        )
        self._error = error

    def stop(self) -> None:
        self.stop_calls += 1
        raise self._error


class RunKwargs(TypedDict):
    """Keyword arguments captured from one ``FakeContainers.run`` call."""

    detach: bool
    environment: dict[str, str]
    labels: dict[str, str]
    mounts: list[Mount]
    network: str


class RunCall(TypedDict):
    """One recorded ``FakeContainers.run`` call: the positional image plus keyword args."""

    args: tuple[str]
    kwargs: RunKwargs


class FakeContainers:
    """Fake ``client.containers`` collection that records ``run``/``get``/``list`` calls.

    ``run`` returns ``run_result``. ``get`` returns ``get_result``, raises it
    when it is an exception, and fails loudly when it was never configured.
    ``list`` raises ``list_error`` when set, otherwise returns ``list_result``.
    """

    def __init__(self, run_result: FakeContainer | None = None) -> None:
        self.run_calls: list[RunCall] = []
        self.get_calls: list[str] = []
        self.list_calls: list[dict[str, object]] = []
        self.run_result = run_result or FakeContainer('container-123')
        self.get_result: FakeContainer | Exception | None = None
        self.list_result: list[object] = []
        self.list_error: Exception | None = None

    def run(
        self,
        image: str,
        *,
        detach: bool,
        environment: dict[str, str],
        labels: dict[str, str],
        mounts: list[Mount],
        network: str,
    ) -> FakeContainer:
        self.run_calls.append(
            {
                'args': (image,),
                'kwargs': {
                    'detach': detach,
                    'environment': environment,
                    'labels': labels,
                    'mounts': mounts,
                    'network': network,
                },
            }
        )
        return self.run_result

    def get(self, container_id: str) -> FakeContainer:
        self.get_calls.append(container_id)
        if isinstance(self.get_result, Exception):
            raise self.get_result
        if self.get_result is None:
            raise AssertionError('missing get result')
        return self.get_result

    def list(self, *, filters: dict[str, list[str]]) -> list[object]:
        self.list_calls.append({'filters': filters})
        if self.list_error is not None:
            raise self.list_error
        return self.list_result


class FakeDockerClient(DockerClient):
    """``DockerClient`` subclass whose ``containers`` attribute is a ``FakeContainers``."""

    def __init__(self, containers: FakeContainers) -> None:
        # Deliberately does not call ``DockerClient.__init__``; presumably so
        # no real API client is constructed -- confirm before changing.
        self._containers = containers

    @property
    def containers(self) -> Any:
        return self._containers


class RecordingMetrics:
    """Metrics sink that appends every call as a ``(name, value, attrs)`` tuple."""

    def __init__(self) -> None:
        self.increment_calls: list[tuple[str, int, Attrs | None]] = []
        self.record_calls: list[tuple[str, float, Attrs | None]] = []
        self.set_calls: list[tuple[str, int | float, Attrs | None]] = []

    def increment(
        self,
        name: str,
        value: int = 1,
        attrs: Attrs | None = None,
    ) -> None:
        self.increment_calls.append((name, value, attrs))

    def record(
        self,
        name: str,
        value: float,
        attrs: Attrs | None = None,
    ) -> None:
        self.record_calls.append((name, value, attrs))

    def set(
        self,
        name: str,
        value: int | float,
        attrs: Attrs | None = None,
    ) -> None:
        self.set_calls.append((name, value, attrs))


class RecordingSpan:
    """Span that stores attributes set on it and errors recorded against it."""

    def __init__(self) -> None:
        self.attrs: dict[str, AttrValue] = {}
        self.errors: list[Exception] = []

    def set_attribute(self, name: str, value: AttrValue) -> None:
        self.attrs[name] = value

    def record_error(self, error: Exception) -> None:
        self.errors.append(error)


class RecordingSpanContext:
    """Context manager that yields a pre-built ``RecordingSpan``."""

    def __init__(self, span: RecordingSpan) -> None:
        self._span = span

    def __enter__(self) -> RecordingSpan:
        return self._span

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        traceback: TracebackType | None,
    ) -> bool | None:
        # Returning None means exceptions raised inside the span propagate.
        return None


class RecordingTracer:
    """Tracer recording each started span as ``(name, start attrs, span)``."""

    def __init__(self) -> None:
        self.spans: list[tuple[str, Attrs | None, RecordingSpan]] = []

    def start_span(
        self,
        name: str,
        attrs: Attrs | None = None,
    ) -> RecordingSpanContext:
        span = RecordingSpan()
        self.spans.append((name, attrs, span))
        return RecordingSpanContext(span)


def _attrs_include(
    actual: Attrs | dict[str, AttrValue] | None,
    expected: dict[str, AttrValue],
) -> bool:
    """Return True when every ``expected`` key maps to an equal value in ``actual``.

    This is a subset match: extra keys in ``actual`` are ignored. Because
    lookups use ``.get``, an expected value of ``None`` also matches an absent
    key. ``actual`` of ``None`` never matches.
    """
    if actual is None:
        return False
    return all(actual.get(name) == value for name, value in expected.items())


def _find_span(
    tracer: RecordingTracer,
    name: str,
    attrs: dict[str, AttrValue] | None = None,
    span_attrs: dict[str, AttrValue] | None = None,
) -> RecordingSpan:
    """Return the first span named ``name`` whose start/span attributes include the given ones.

    Raises:
        AssertionError: if no recorded span matches.
    """
    for recorded_name, recorded_attrs, span in tracer.spans:
        if recorded_name != name:
            continue
        if attrs is not None and not _attrs_include(recorded_attrs, attrs):
            continue
        if span_attrs is not None and not _attrs_include(span.attrs, span_attrs):
            continue
        return span
    raise AssertionError(f'missing span {name}')


def _find_increment_call(
    metrics: RecordingMetrics,
    name: str,
    *,
    value: int = 1,
    attrs: dict[str, AttrValue] | None = None,
) -> tuple[str, int, Attrs | None]:
    """Return the first increment of ``name`` by exactly ``value`` whose attrs include ``attrs``.

    Raises:
        AssertionError: if no recorded increment matches.
    """
    for recorded_name, recorded_value, recorded_attrs in metrics.increment_calls:
        if recorded_name != name or recorded_value != value:
            continue
        if attrs is not None and not _attrs_include(recorded_attrs, attrs):
            continue
        return recorded_name, recorded_value, recorded_attrs
    raise AssertionError(f'missing increment metric {name}')


def _find_record_call(
    metrics: RecordingMetrics,
    name: str,
    *,
    attrs: dict[str, AttrValue] | None = None,
) -> tuple[str, float, Attrs | None]:
    """Return the first ``record`` call for ``name`` whose attrs include ``attrs``.

    Raises:
        AssertionError: if no recorded call matches.
    """
    for recorded_name, recorded_value, recorded_attrs in metrics.record_calls:
        if recorded_name != name:
            continue
        if attrs is not None and not _attrs_include(recorded_attrs, attrs):
            continue
        return recorded_name, recorded_value, recorded_attrs
    raise AssertionError(f'missing record metric {name}')


def _runtime_error_increments(
    metrics: RecordingMetrics,
    operation: str,
) -> list[tuple[str, int, Attrs | None]]:
    """Return every ``sandbox.runtime.error.total`` increment tagged with ``operation``."""
    return [
        call
        for call in metrics.increment_calls
        if call[0] == 'sandbox.runtime.error.total'
        and call[2] is not None
        and call[2].get('operation') == operation
    ]


def build_config(tmp_path: Path) -> SandboxConfig:
    """Build a ``SandboxConfig`` whose host paths all live under ``tmp_path``."""
    return SandboxConfig(
        image='sandbox:latest',
        network_name='sandbox',
        agent_service_port=8000,
        ttl_seconds=300,
        cleanup_interval_seconds=60,
        chats_root=str(tmp_path / 'chats'),
        dependencies_host_path=str(tmp_path / 'dependencies'),
        lambda_tools_host_path=str(tmp_path / 'lambda-tools'),
        chat_mount_path='/workspace/chat',
        dependencies_mount_path='/workspace/dependencies',
        lambda_tools_mount_path='/workspace/lambda-tools',
        volume_mount_path='/workspace/volume',
        extra_env={},
    )


def _make_host_dirs(tmp_path: Path) -> None:
    """Create the dependencies and lambda-tools host directories ``build_config`` points at.

    The create tests set these up before calling ``runtime.create``.
    """
    (tmp_path / 'dependencies').mkdir()
    (tmp_path / 'lambda-tools').mkdir()


def build_runtime(
    config: SandboxConfig,
    containers: FakeContainers,
    *,
    metrics: RecordingMetrics | None = None,
    tracer: RecordingTracer | None = None,
) -> DockerSandboxRuntime:
    """Build a runtime over ``containers``.

    When no recording ``metrics``/``tracer`` are given, the no-op
    implementations are used.
    """
    return DockerSandboxRuntime(
        config,
        FakeDockerClient(containers),
        NoopMetrics() if metrics is None else metrics,
        NoopTracer() if tracer is None else tracer,
    )


def _core_labels() -> dict[str, str]:
    """Return a fresh session/chat/expiry label set for containers served by ``get``."""
    return {
        'session_id': str(SESSION_ID),
        'chat_id': str(CHAT_ID),
        'expires_at': '2026-04-02T12:05:00+00:00',
    }


def _mixed_listed_containers(tmp_path: Path, expires_at: datetime) -> list[object]:
    """Return one fully labelled container followed by one missing most required labels."""
    return [
        FakeListedContainer(
            'container-123',
            labels={
                'session_id': str(SESSION_ID),
                'chat_id': str(CHAT_ID),
                'expires_at': expires_at.isoformat(),
                'agent_id': AGENT_ID,
                'volume_host_path': str(tmp_path / 'request-volume'),
                'endpoint_port': '8000',
            },
            created_at='2026-04-02T12:00:00Z',
        ),
        FakeListedContainer(
            'container-bad',
            labels={
                'chat_id': str(CHAT_ID),
                'expires_at': expires_at.isoformat(),
            },
            created_at='2026-04-02T12:01:00Z',
        ),
    ]


def test_runtime_create_applies_mount_policy_and_labels_with_canonical_chat_id(
    tmp_path: Path,
) -> None:
    """Create passes image, env, network, labels and the four bind mounts to ``run``."""
    config = build_config(tmp_path)
    _make_host_dirs(tmp_path)
    containers = FakeContainers()
    runtime = build_runtime(config, containers)
    created_at = datetime(2026, 4, 2, 12, 0, tzinfo=UTC)
    expires_at = created_at + timedelta(minutes=5)

    session = runtime.create(
        session_id=SESSION_ID,
        chat_id=UUID(NON_CANONICAL_CHAT_ID),
        agent_id=AGENT_ID,
        volume_host_path=str(tmp_path / 'request-volume'),
        created_at=created_at,
        expires_at=expires_at,
    )

    assert session.session_id == SESSION_ID
    assert session.chat_id == CHAT_ID
    assert session.container_id == 'container-123'
    assert session.status is SandboxStatus.RUNNING
    assert session.created_at == created_at
    assert session.expires_at == expires_at
    assert session.agent_id == AGENT_ID
    assert session.volume_host_path == str(
        (tmp_path / 'request-volume').resolve(strict=False)
    )
    assert session.endpoint == SandboxEndpoint(ip='172.20.0.8', port=8000)
    assert (tmp_path / 'chats' / str(CHAT_ID)).is_dir()

    call = containers.run_calls[0]
    assert call['args'] == ('sandbox:latest',)
    assert call['kwargs']['detach'] is True
    assert call['kwargs']['environment'] == {'AGENT_ID': AGENT_ID}
    assert call['kwargs']['network'] == 'sandbox'
    assert call['kwargs']['labels'] == {
        'session_id': str(SESSION_ID),
        'chat_id': str(CHAT_ID),
        'expires_at': expires_at.isoformat(),
        'agent_id': AGENT_ID,
        'volume_host_path': str((tmp_path / 'request-volume').resolve(strict=False)),
        'endpoint_port': '8000',
    }
    mounts = call['kwargs']['mounts']
    assert [dict(mount) for mount in mounts] == [
        {
            'Target': '/workspace/chat',
            'Source': str((tmp_path / 'chats' / str(CHAT_ID)).resolve(strict=False)),
            'Type': 'bind',
            'ReadOnly': False,
        },
        {
            'Target': '/workspace/dependencies',
            'Source': str((tmp_path / 'dependencies').resolve(strict=False)),
            'Type': 'bind',
            'ReadOnly': True,
        },
        {
            'Target': '/workspace/lambda-tools',
            'Source': str((tmp_path / 'lambda-tools').resolve(strict=False)),
            'Type': 'bind',
            'ReadOnly': True,
        },
        {
            'Target': '/workspace/volume',
            'Source': str((tmp_path / 'request-volume').resolve(strict=False)),
            'Type': 'bind',
            'ReadOnly': False,
        },
    ]


def test_runtime_create_uses_configured_network_for_endpoint(tmp_path: Path) -> None:
    """The endpoint IP comes from the configured network; the port from the config."""
    config = replace(
        build_config(tmp_path),
        network_name='agent-net',
        agent_service_port=9000,
    )
    _make_host_dirs(tmp_path)
    containers = FakeContainers(
        run_result=FakeContainer(
            'container-456',
            network_name='agent-net',
            ip='10.42.0.7',
        )
    )
    runtime = build_runtime(config, containers)
    created_at = datetime(2026, 4, 2, 12, 0, tzinfo=UTC)
    expires_at = created_at + timedelta(minutes=5)

    session = runtime.create(
        session_id=SESSION_ID,
        chat_id=CHAT_ID,
        agent_id=AGENT_ID,
        volume_host_path=str(tmp_path / 'request-volume'),
        created_at=created_at,
        expires_at=expires_at,
    )

    assert containers.run_calls[0]['kwargs']['network'] == 'agent-net'
    assert session.endpoint == SandboxEndpoint(ip='10.42.0.7', port=9000)


def test_runtime_create_removes_container_when_endpoint_extraction_fails(
    tmp_path: Path,
) -> None:
    """A container on an unexpected network is force-removed and the start fails."""
    config = build_config(tmp_path)
    _make_host_dirs(tmp_path)
    created_container = FakeContainer(
        'container-789',
        network_name='unexpected-net',
    )
    containers = FakeContainers(run_result=created_container)
    runtime = build_runtime(config, containers)

    with pytest.raises(SandboxStartError) as excinfo:
        runtime.create(
            session_id=SESSION_ID,
            chat_id=CHAT_ID,
            agent_id=AGENT_ID,
            volume_host_path=str(tmp_path / 'request-volume'),
            created_at=datetime(2026, 4, 2, 12, 0, tzinfo=UTC),
            expires_at=datetime(2026, 4, 2, 12, 5, tzinfo=UTC),
        )

    assert str(excinfo.value) == 'sandbox_start_failed'
    assert containers.run_calls
    assert created_container.remove_calls == [{'force': True}]


def test_runtime_create_applies_request_volume_bind_as_rw(tmp_path: Path) -> None:
    """The request volume is bind-mounted read-write at the configured mount path."""
    config = build_config(tmp_path)
    _make_host_dirs(tmp_path)
    containers = FakeContainers()
    runtime = build_runtime(config, containers)
    created_at = datetime(2026, 4, 2, 12, 0, tzinfo=UTC)
    expires_at = created_at + timedelta(minutes=5)
    volume_host_path = str(tmp_path / 'request-volume')

    runtime.create(
        session_id=SESSION_ID,
        chat_id=CHAT_ID,
        agent_id=AGENT_ID,
        volume_host_path=volume_host_path,
        created_at=created_at,
        expires_at=expires_at,
    )

    mounts = [dict(mount) for mount in containers.run_calls[0]['kwargs']['mounts']]
    assert {
        'Target': '/workspace/volume',
        'Source': str((tmp_path / 'request-volume').resolve(strict=False)),
        'Type': 'bind',
        'ReadOnly': False,
    } in mounts


def test_runtime_create_records_observability(tmp_path: Path) -> None:
    """A successful create records a 'created' duration metric and an error-free span."""
    config = build_config(tmp_path)
    _make_host_dirs(tmp_path)
    containers = FakeContainers()
    metrics = RecordingMetrics()
    tracer = RecordingTracer()
    runtime = build_runtime(config, containers, metrics=metrics, tracer=tracer)
    created_at = datetime(2026, 4, 2, 12, 0, tzinfo=UTC)
    expires_at = created_at + timedelta(minutes=5)

    session = runtime.create(
        session_id=SESSION_ID,
        chat_id=CHAT_ID,
        agent_id=AGENT_ID,
        volume_host_path=str(tmp_path / 'request-volume'),
        created_at=created_at,
        expires_at=expires_at,
    )

    assert session.container_id == 'container-123'
    duration_call = _find_record_call(
        metrics,
        'sandbox.runtime.create.duration_ms',
        attrs={'operation': 'create', 'result': 'created'},
    )
    assert duration_call[1] >= 0
    span = _find_span(
        tracer,
        'adapter.docker.create_sandbox',
        {'chat.id': str(CHAT_ID), 'session.id': str(SESSION_ID)},
        {
            'container.id': 'container-123',
            'sandbox.result': 'created',
        },
    )
    assert not span.errors


def test_runtime_create_raises_start_error_when_container_id_is_missing(
    tmp_path: Path,
) -> None:
    """An empty container id surfaces as ``SandboxStartError`` tagged with the chat id."""
    config = build_config(tmp_path)
    _make_host_dirs(tmp_path)
    containers = FakeContainers(run_result=FakeContainer(''))
    runtime = build_runtime(config, containers)

    with pytest.raises(SandboxStartError) as excinfo:
        runtime.create(
            session_id=SESSION_ID,
            chat_id=CHAT_ID,
            agent_id=AGENT_ID,
            volume_host_path=str(tmp_path / 'request-volume'),
            created_at=datetime(2026, 4, 2, 12, 0, tzinfo=UTC),
            expires_at=datetime(2026, 4, 2, 12, 5, tzinfo=UTC),
        )

    assert str(excinfo.value) == 'sandbox_start_failed'
    assert excinfo.value.chat_id == str(CHAT_ID)


def test_runtime_create_error_records_observability_when_container_id_missing(
    tmp_path: Path,
) -> None:
    """A failed create records an error counter, an 'error' duration and the span error."""
    config = build_config(tmp_path)
    _make_host_dirs(tmp_path)
    containers = FakeContainers(run_result=FakeContainer(''))
    metrics = RecordingMetrics()
    tracer = RecordingTracer()
    runtime = build_runtime(config, containers, metrics=metrics, tracer=tracer)

    with pytest.raises(SandboxStartError) as excinfo:
        runtime.create(
            session_id=SESSION_ID,
            chat_id=CHAT_ID,
            agent_id=AGENT_ID,
            volume_host_path=str(tmp_path / 'request-volume'),
            created_at=datetime(2026, 4, 2, 12, 0, tzinfo=UTC),
            expires_at=datetime(2026, 4, 2, 12, 5, tzinfo=UTC),
        )

    assert str(excinfo.value) == 'sandbox_start_failed'
    _find_increment_call(
        metrics,
        'sandbox.runtime.error.total',
        attrs={'operation': 'create', 'error.type': 'ValueError'},
    )
    duration_call = _find_record_call(
        metrics,
        'sandbox.runtime.create.duration_ms',
        attrs={'operation': 'create', 'result': 'error'},
    )
    assert duration_call[1] >= 0
    span = _find_span(
        tracer,
        'adapter.docker.create_sandbox',
        {'chat.id': str(CHAT_ID), 'session.id': str(SESSION_ID)},
        {'sandbox.result': 'error'},
    )
    assert excinfo.value in span.errors


def test_runtime_stop_ignores_missing_container(tmp_path: Path) -> None:
    """Stopping an unknown container is a 'not_found' no-op, not an error."""
    config = build_config(tmp_path)
    containers = FakeContainers()
    containers.get_result = NotFound('missing')
    metrics = RecordingMetrics()
    tracer = RecordingTracer()
    runtime = build_runtime(config, containers, metrics=metrics, tracer=tracer)

    runtime.stop('container-123')

    assert containers.get_calls == ['container-123']
    duration_call = _find_record_call(
        metrics,
        'sandbox.runtime.stop.duration_ms',
        attrs={'operation': 'stop', 'result': 'not_found'},
    )
    assert duration_call[1] >= 0
    span = _find_span(
        tracer,
        'adapter.docker.stop_sandbox',
        {'container.id': 'container-123'},
        {'sandbox.result': 'not_found'},
    )
    assert not span.errors
    assert _runtime_error_increments(metrics, 'stop') == []


def test_runtime_stop_wraps_docker_errors(tmp_path: Path) -> None:
    """A Docker failure during stop is wrapped in ``SandboxError`` with the cause recorded."""
    config = build_config(tmp_path)
    containers = FakeContainers()
    containers.get_result = FailingStopContainer(
        'container-123',
        labels=_core_labels(),
        created_at='2026-04-02T12:00:00Z',
        error=DockerException('boom'),
    )
    metrics = RecordingMetrics()
    tracer = RecordingTracer()
    runtime = build_runtime(config, containers, metrics=metrics, tracer=tracer)

    with pytest.raises(SandboxError) as excinfo:
        runtime.stop('container-123')

    assert str(excinfo.value) == 'sandbox_stop_failed'
    _find_increment_call(
        metrics,
        'sandbox.runtime.error.total',
        attrs={'operation': 'stop', 'error.type': 'DockerException'},
    )
    duration_call = _find_record_call(
        metrics,
        'sandbox.runtime.stop.duration_ms',
        attrs={'operation': 'stop', 'result': 'error'},
    )
    assert duration_call[1] >= 0
    span = _find_span(
        tracer,
        'adapter.docker.stop_sandbox',
        {'container.id': 'container-123'},
        {
            'session.id': str(SESSION_ID),
            'chat.id': str(CHAT_ID),
            'sandbox.result': 'error',
        },
    )
    cause = excinfo.value.__cause__
    assert isinstance(cause, DockerException)
    assert cause in span.errors


def test_runtime_stop_records_observability_on_success(tmp_path: Path) -> None:
    """A successful stop records a 'stopped' duration, span labels and no error counter."""
    config = build_config(tmp_path)
    containers = FakeContainers()
    container = FakeListedContainer(
        'container-123',
        labels=_core_labels(),
        created_at='2026-04-02T12:00:00Z',
    )
    containers.get_result = container
    metrics = RecordingMetrics()
    tracer = RecordingTracer()
    runtime = build_runtime(config, containers, metrics=metrics, tracer=tracer)

    runtime.stop('container-123')

    assert container.stop_calls == 1
    duration_call = _find_record_call(
        metrics,
        'sandbox.runtime.stop.duration_ms',
        attrs={'operation': 'stop', 'result': 'stopped'},
    )
    assert duration_call[1] >= 0
    span = _find_span(
        tracer,
        'adapter.docker.stop_sandbox',
        {'container.id': 'container-123'},
        {
            'session.id': str(SESSION_ID),
            'chat.id': str(CHAT_ID),
            'sandbox.result': 'stopped',
        },
    )
    assert not span.errors
    assert _runtime_error_increments(metrics, 'stop') == []


def test_runtime_delete_removes_container_with_force(tmp_path: Path) -> None:
    """Delete looks the container up and removes it with ``force=True``."""
    config = build_config(tmp_path)
    containers = FakeContainers()
    container = FakeListedContainer(
        'container-123',
        labels=_core_labels(),
        created_at='2026-04-02T12:00:00Z',
    )
    containers.get_result = container
    runtime = build_runtime(config, containers)

    runtime.delete('container-123')

    assert containers.get_calls == ['container-123']
    assert container.remove_calls == [{'force': True}]


def test_runtime_delete_ignores_missing_container(tmp_path: Path) -> None:
    """Deleting an unknown container does not raise."""
    config = build_config(tmp_path)
    containers = FakeContainers()
    containers.get_result = NotFound('missing')
    runtime = build_runtime(config, containers)

    runtime.delete('container-123')

    assert containers.get_calls == ['container-123']


def test_runtime_list_active_sessions_reads_valid_labeled_containers(
    tmp_path: Path,
) -> None:
    """Only fully labelled containers become sessions; list filters on every label."""
    config = build_config(tmp_path)
    containers = FakeContainers()
    expires_at = datetime(2026, 4, 2, 12, 5, tzinfo=UTC)
    containers.list_result = _mixed_listed_containers(tmp_path, expires_at)
    runtime = build_runtime(config, containers)

    sessions = runtime.list_active_sessions()

    assert sessions == [
        SandboxSession(
            session_id=SESSION_ID,
            chat_id=CHAT_ID,
            container_id='container-123',
            status=SandboxStatus.RUNNING,
            created_at=datetime(2026, 4, 2, 12, 0, tzinfo=UTC),
            expires_at=expires_at,
            agent_id=AGENT_ID,
            volume_host_path=str(tmp_path / 'request-volume'),
            endpoint=SandboxEndpoint(ip='172.20.0.8', port=8000),
        )
    ]
    assert containers.list_calls == [
        {
            'filters': {
                'label': [
                    'session_id',
                    'chat_id',
                    'expires_at',
                    'agent_id',
                    'volume_host_path',
                    'endpoint_port',
                ]
            }
        }
    ]


def test_runtime_list_active_records_observability(tmp_path: Path) -> None:
    """Listing records a 'listed' duration and span container/active counts."""
    config = build_config(tmp_path)
    containers = FakeContainers()
    expires_at = datetime(2026, 4, 2, 12, 5, tzinfo=UTC)
    containers.list_result = _mixed_listed_containers(tmp_path, expires_at)
    metrics = RecordingMetrics()
    tracer = RecordingTracer()
    runtime = build_runtime(config, containers, metrics=metrics, tracer=tracer)

    sessions = runtime.list_active_sessions()

    assert len(sessions) == 1
    duration_call = _find_record_call(
        metrics,
        'sandbox.runtime.list_active.duration_ms',
        attrs={'operation': 'list_active', 'result': 'listed'},
    )
    assert duration_call[1] >= 0
    span = _find_span(
        tracer,
        'adapter.docker.list_active_sandboxes',
        span_attrs={
            'sandbox.container_count': 2,
            'sandbox.active_count': 1,
            'sandbox.result': 'listed',
        },
    )
    assert not span.errors


def test_runtime_list_active_error_records_observability(tmp_path: Path) -> None:
    """A Docker failure while listing is wrapped and recorded on metrics and span."""
    config = build_config(tmp_path)
    containers = FakeContainers()
    containers.list_error = DockerException('boom')
    metrics = RecordingMetrics()
    tracer = RecordingTracer()
    runtime = build_runtime(config, containers, metrics=metrics, tracer=tracer)

    with pytest.raises(SandboxError) as excinfo:
        runtime.list_active_sessions()

    assert str(excinfo.value) == 'sandbox_list_failed'
    _find_increment_call(
        metrics,
        'sandbox.runtime.error.total',
        attrs={'operation': 'list_active', 'error.type': 'DockerException'},
    )
    duration_call = _find_record_call(
        metrics,
        'sandbox.runtime.list_active.duration_ms',
        attrs={'operation': 'list_active', 'result': 'error'},
    )
    assert duration_call[1] >= 0
    span = _find_span(
        tracer,
        'adapter.docker.list_active_sandboxes',
        span_attrs={'sandbox.result': 'error'},
    )
    assert isinstance(excinfo.value.__cause__, DockerException)
    assert excinfo.value in span.errors