master/adapter/docker/runtime.py

454 lines
16 KiB
Python

import time
from datetime import datetime
from pathlib import Path
from uuid import UUID
from docker import DockerClient
from docker.errors import DockerException, NotFound
from docker.types import Mount
from adapter.config.model import SandboxConfig
from domain.error import SandboxError, SandboxStartError
from domain.sandbox import SandboxEndpoint, SandboxSession, SandboxStatus
from usecase.interface import Metrics, SandboxRuntime, Span, Tracer
# Label keys stamped on every sandbox container. The same keys are used as
# the label filter when enumerating active sandboxes.
SANDBOX_LABELS = (
    'session_id',
    'chat_id',
    'expires_at',
    'agent_id',
    'volume_host_path',
    'endpoint_port',
)
class DockerSandboxRuntime(SandboxRuntime):
    """Sandbox runtime that runs each sandbox as a Docker container.

    Every public operation opens a tracer span, records a duration metric
    tagged with the operation result, and increments
    ``sandbox.runtime.error.total`` on failure.
    """

    def __init__(
        self,
        config: SandboxConfig,
        client: DockerClient,
        metrics: Metrics,
        tracer: Tracer,
    ) -> None:
        """Store the collaborators; no Docker call is made here."""
        self._config = config
        self._client = client
        self._metrics = metrics
        self._tracer = tracer

    def create(
        self,
        *,
        session_id: UUID,
        chat_id: UUID,
        agent_id: str,
        volume_host_path: str,
        created_at: datetime,
        expires_at: datetime,
    ) -> SandboxSession:
        """Start a sandbox container and return the session describing it.

        Args:
            session_id: Session identifier, stored as a container label.
            chat_id: Chat identifier; selects the per-chat directory that is
                bind-mounted into the container.
            agent_id: Passed to the container as ``AGENT_ID`` and stored as
                a label.
            volume_host_path: Absolute host path bind-mounted as the sandbox
                volume.
            created_at: Creation timestamp copied into the returned session.
            expires_at: Expiry timestamp, stored as a label and copied into
                the returned session.

        Returns:
            A ``SandboxSession`` in ``RUNNING`` status carrying the container
            id and its network endpoint.

        Raises:
            SandboxStartError: If path validation, directory creation,
                container start, or endpoint discovery fails with a
                ``DockerException``, ``OSError`` or ``ValueError``. A
                container that was already started is removed first.
        """
        started_at = time.perf_counter()
        result = 'error'
        with self._tracer.start_span(
            'adapter.docker.create_sandbox',
            attrs={
                'chat.id': str(chat_id),
                'session.id': str(session_id),
            },
        ) as span:
            try:
                try:
                    chat_path = self._chat_path(chat_id)
                    volume_path = self._request_host_path(volume_host_path)
                    dependencies_path = self._readonly_host_path(
                        self._config.dependencies_host_path
                    )
                    lambda_tools_path = self._readonly_host_path(
                        self._config.lambda_tools_host_path
                    )
                    chat_path.mkdir(parents=True, exist_ok=True)
                    container = self._client.containers.run(
                        self._config.image,
                        detach=True,
                        environment={**self._config.extra_env, 'AGENT_ID': agent_id},
                        labels=self._labels(
                            session_id,
                            chat_id,
                            expires_at,
                            agent_id,
                            str(volume_path),
                        ),
                        mounts=self._mounts(
                            chat_path,
                            volume_path,
                            dependencies_path,
                            lambda_tools_path,
                        ),
                        network=self._config.network_name,
                    )
                    # From here on a container exists, so any failure must
                    # remove it before the error is reported.
                    try:
                        container_id = str(getattr(container, 'id', '')).strip()
                        if not container_id:
                            raise ValueError('invalid container id')
                        endpoint = self._endpoint_from_container(container)
                    except (DockerException, OSError, ValueError) as exc:
                        self._remove_created_container(container, str(chat_id), exc)
                        raise SandboxStartError(str(chat_id), str(exc)) from exc
                except SandboxStartError:
                    # Already translated above; do not wrap it a second time.
                    raise
                except (DockerException, OSError, ValueError) as exc:
                    raise SandboxStartError(str(chat_id), str(exc)) from exc
                result = 'created'
                span.set_attribute('container.id', container_id)
                span.set_attribute('agent.id', agent_id)
                span.set_attribute('sandbox.endpoint.ip', endpoint.ip)
                span.set_attribute('sandbox.endpoint.port', endpoint.port)
                span.set_attribute('sandbox.result', result)
                return SandboxSession(
                    session_id=session_id,
                    chat_id=chat_id,
                    container_id=container_id,
                    status=SandboxStatus.RUNNING,
                    created_at=created_at,
                    expires_at=expires_at,
                    agent_id=agent_id,
                    volume_host_path=str(volume_path),
                    endpoint=endpoint,
                )
            except Exception as exc:
                span.set_attribute('sandbox.result', result)
                span.record_error(exc)
                self._metrics.increment(
                    'sandbox.runtime.error.total',
                    attrs=_runtime_error_metric_attrs('create', _error_type(exc)),
                )
                raise
            finally:
                self._metrics.record(
                    'sandbox.runtime.create.duration_ms',
                    _duration_ms(started_at),
                    attrs=_runtime_metric_attrs('create', result),
                )

    def stop(self, container_id: str) -> None:
        """Stop the container with the given id.

        A container Docker no longer knows about is not an error: the
        result is recorded as ``not_found`` and the method returns.

        Raises:
            SandboxError: ``sandbox_stop_failed`` on any other
                ``DockerException``.
        """
        started_at = time.perf_counter()
        result = 'error'
        with self._tracer.start_span(
            'adapter.docker.stop_sandbox',
            attrs={'container.id': container_id},
        ) as span:
            try:
                container = self._client.containers.get(container_id)
                _set_span_container_attrs(span, container)
                container.stop()
                result = 'stopped'
                span.set_attribute('sandbox.result', result)
            except NotFound:
                # Must be handled before the DockerException clause below.
                result = 'not_found'
                span.set_attribute('sandbox.result', result)
                return
            except DockerException as exc:
                span.set_attribute('sandbox.result', result)
                span.record_error(exc)
                self._metrics.increment(
                    'sandbox.runtime.error.total',
                    attrs=_runtime_error_metric_attrs('stop', type(exc).__name__),
                )
                raise SandboxError('sandbox_stop_failed') from exc
            finally:
                self._metrics.record(
                    'sandbox.runtime.stop.duration_ms',
                    _duration_ms(started_at),
                    attrs=_runtime_metric_attrs('stop', result),
                )

    def delete(self, container_id: str) -> None:
        """Force-remove the container with the given id.

        A container Docker no longer knows about is not an error: the
        result is recorded as ``not_found`` and the method returns.

        Raises:
            SandboxError: ``sandbox_delete_failed`` on any other
                ``DockerException``.
        """
        started_at = time.perf_counter()
        result = 'error'
        with self._tracer.start_span(
            'adapter.docker.delete_sandbox',
            attrs={'container.id': container_id},
        ) as span:
            try:
                container = self._client.containers.get(container_id)
                _set_span_container_attrs(span, container)
                container.remove(force=True)
                result = 'deleted'
                span.set_attribute('sandbox.result', result)
            except NotFound:
                # Must be handled before the DockerException clause below.
                result = 'not_found'
                span.set_attribute('sandbox.result', result)
                return
            except DockerException as exc:
                span.set_attribute('sandbox.result', result)
                span.record_error(exc)
                self._metrics.increment(
                    'sandbox.runtime.error.total',
                    attrs=_runtime_error_metric_attrs('delete', type(exc).__name__),
                )
                raise SandboxError('sandbox_delete_failed') from exc
            finally:
                self._metrics.record(
                    'sandbox.runtime.delete.duration_ms',
                    _duration_ms(started_at),
                    attrs=_runtime_metric_attrs('delete', result),
                )

    def list_active_sessions(self) -> list[SandboxSession]:
        """Rebuild sessions from the containers carrying all sandbox labels.

        Containers whose labels or attributes cannot be turned into a valid
        session, or that disappear while being inspected, are skipped.

        Returns:
            One ``SandboxSession`` per usable container.

        Raises:
            SandboxError: ``sandbox_list_failed`` when Docker fails while
                listing or inspecting containers.
        """
        started_at = time.perf_counter()
        result = 'error'
        with self._tracer.start_span(
            'adapter.docker.list_active_sandboxes',
        ) as span:
            try:
                try:
                    containers = self._client.containers.list(
                        filters={'label': list(SANDBOX_LABELS)}
                    )
                    # Building a session reloads each container from the
                    # daemon, so the loop sits inside the same Docker error
                    # boundary as the listing call itself.
                    sessions: list[SandboxSession] = []
                    for container in containers:
                        session = self._session_from_container(container)
                        if session is None:
                            continue
                        sessions.append(session)
                except DockerException as exc:
                    raise SandboxError('sandbox_list_failed') from exc
                result = 'listed'
                span.set_attribute('sandbox.container_count', len(containers))
                span.set_attribute('sandbox.active_count', len(sessions))
                span.set_attribute('sandbox.result', result)
                return sessions
            except Exception as exc:
                span.set_attribute('sandbox.result', result)
                span.record_error(exc)
                self._metrics.increment(
                    'sandbox.runtime.error.total',
                    attrs=_runtime_error_metric_attrs('list_active', _error_type(exc)),
                )
                raise
            finally:
                self._metrics.record(
                    'sandbox.runtime.list_active.duration_ms',
                    _duration_ms(started_at),
                    attrs=_runtime_metric_attrs('list_active', result),
                )

    def _labels(
        self,
        session_id: UUID,
        chat_id: UUID,
        expires_at: datetime,
        agent_id: str,
        volume_host_path: str,
    ) -> dict[str, str]:
        """Return the container labels from which a session can be rebuilt."""
        return {
            'session_id': str(session_id),
            'chat_id': str(chat_id),
            'expires_at': expires_at.isoformat(),
            'agent_id': agent_id,
            'volume_host_path': volume_host_path,
            'endpoint_port': str(self._config.agent_service_port),
        }

    def _mounts(
        self,
        chat_path: Path,
        volume_path: Path,
        dependencies_path: Path,
        lambda_tools_path: Path,
    ) -> list[Mount]:
        """Return the four bind mounts of a sandbox container.

        The dependencies and lambda-tools mounts are read-only; the chat and
        volume mounts are declared without the read-only flag.
        """
        return [
            Mount(
                target=self._config.chat_mount_path,
                source=str(chat_path),
                type='bind',
            ),
            Mount(
                target=self._config.dependencies_mount_path,
                source=str(dependencies_path),
                type='bind',
                read_only=True,
            ),
            Mount(
                target=self._config.lambda_tools_mount_path,
                source=str(lambda_tools_path),
                type='bind',
                read_only=True,
            ),
            Mount(
                target=self._config.volume_mount_path,
                source=str(volume_path),
                type='bind',
            ),
        ]

    def _chat_path(self, chat_id: UUID) -> Path:
        """Return the resolved per-chat directory under the chats root.

        Raises:
            ValueError: If the resolved path is not inside the chats root.
        """
        chats_root = self._host_path(self._config.chats_root)
        chat_path = (chats_root / str(chat_id)).resolve(strict=False)
        if not chat_path.is_relative_to(chats_root):
            raise ValueError('invalid chat path')
        return chat_path

    def _readonly_host_path(self, path_value: str) -> Path:
        """Resolve a configured host path and require that it exists.

        Raises:
            ValueError: If the resolved path does not exist.
        """
        host_path = self._host_path(path_value)
        if not host_path.exists():
            raise ValueError('invalid host path')
        return host_path

    def _request_host_path(self, path_value: str) -> Path:
        """Resolve a caller-supplied host path, which must be absolute.

        Raises:
            ValueError: If the path (after ``~`` expansion) is not absolute.
        """
        # NOTE(review): only absoluteness is checked here; the path is not
        # restricted to any allowed root. Confirm callers validate it.
        host_path = Path(path_value).expanduser()
        if not host_path.is_absolute():
            raise ValueError('invalid host path')
        return host_path.resolve(strict=False)

    def _remove_created_container(
        self,
        container: object,
        chat_id: str,
        error: Exception,
    ) -> None:
        """Force-remove a container whose start-up failed.

        Args:
            container: Object returned by ``containers.run``.
            chat_id: Chat identifier used in the raised error.
            error: The failure that triggered the cleanup.

        Raises:
            SandboxStartError: If the container cannot be removed (no
                callable ``remove`` or the removal itself fails). A
                container that is already gone is not an error.
        """
        remove = getattr(container, 'remove', None)
        if not callable(remove):
            raise SandboxStartError(chat_id) from error
        try:
            remove(force=True)
        except NotFound:
            return
        except Exception as exc:
            raise SandboxStartError(chat_id) from exc

    def _session_from_container(self, container: object) -> SandboxSession | None:
        """Rebuild a session from a container's labels and attributes.

        Returns:
            The session, or ``None`` when the container has no id, its labels
            or attributes are missing or invalid, or the container vanished
            while being inspected.
        """
        container_id = str(getattr(container, 'id', '')).strip()
        labels = getattr(container, 'labels', None)
        if not container_id or not isinstance(labels, dict):
            return None
        try:
            session_id = UUID(labels['session_id'])
            chat_id = UUID(labels['chat_id'])
            agent_id = labels['agent_id']
            volume_host_path = labels['volume_host_path']
            endpoint_port = int(labels['endpoint_port'])
            if not isinstance(agent_id, str) or not isinstance(volume_host_path, str):
                raise ValueError('invalid sandbox labels')
            if not Path(volume_host_path).is_absolute() or endpoint_port <= 0:
                raise ValueError('invalid sandbox labels')
            endpoint = self._endpoint_from_container(container, endpoint_port)
            created_at = self._container_created_at(container)
            expires_at = _parse_datetime(labels['expires_at'])
        except (KeyError, TypeError, ValueError, NotFound):
            # NotFound: the container was removed between the listing and
            # the reload, so it is simply no longer an active sandbox.
            return None
        return SandboxSession(
            session_id=session_id,
            chat_id=chat_id,
            container_id=container_id,
            status=SandboxStatus.RUNNING,
            created_at=created_at,
            expires_at=expires_at,
            agent_id=agent_id,
            volume_host_path=volume_host_path,
            endpoint=endpoint,
        )

    def _container_created_at(self, container: object) -> datetime:
        """Return the container's ``Created`` attribute as a datetime.

        Raises:
            ValueError: If the attribute is missing, not a string, or cannot
                be parsed.
        """
        attrs = self._container_attrs(container)
        raw_created_at = attrs.get('Created')
        if not isinstance(raw_created_at, str):
            raise ValueError('invalid created_at')
        return _parse_datetime(raw_created_at)

    def _endpoint_from_container(
        self,
        container: object,
        port: int | None = None,
    ) -> SandboxEndpoint:
        """Return the container's endpoint on the configured network.

        Args:
            container: Docker container object.
            port: Port to report; the configured agent service port is used
                when ``None``.

        Raises:
            ValueError: If the container attributes do not contain a
                non-empty IP address for the configured network.
        """
        attrs = self._container_attrs(container)
        network_settings = attrs.get('NetworkSettings')
        if not isinstance(network_settings, dict):
            raise ValueError('invalid endpoint')
        networks = network_settings.get('Networks')
        if not isinstance(networks, dict):
            raise ValueError('invalid endpoint')
        network = networks.get(self._config.network_name)
        if not isinstance(network, dict):
            raise ValueError('invalid endpoint')
        ip = network.get('IPAddress')
        if not isinstance(ip, str) or not ip:
            raise ValueError('invalid endpoint')
        endpoint_port = self._config.agent_service_port if port is None else port
        return SandboxEndpoint(ip=ip, port=endpoint_port)

    def _container_attrs(self, container: object) -> dict[str, object]:
        """Reload the container when it supports it and return its attrs.

        Raises:
            ValueError: If the container's ``attrs`` is not a dict.
        """
        reload_container = getattr(container, 'reload', None)
        if callable(reload_container):
            reload_container()
        attrs = getattr(container, 'attrs', None)
        if not isinstance(attrs, dict):
            raise ValueError('invalid container attrs')
        return attrs

    def _host_path(self, path_value: str) -> Path:
        """Expand ``~`` and resolve the path without requiring it to exist."""
        return Path(path_value).expanduser().resolve(strict=False)
def _parse_datetime(value: str) -> datetime:
normalized = f'{value[:-1]}+00:00' if value.endswith('Z') else value
return datetime.fromisoformat(normalized)
def _duration_ms(started_at: float) -> float:
return (time.perf_counter() - started_at) * 1000
def _runtime_metric_attrs(operation: str, result: str) -> dict[str, str]:
return {
'operation': operation,
'result': result,
}
def _runtime_error_metric_attrs(
operation: str,
error_type: str,
) -> dict[str, str]:
return {
'operation': operation,
'error.type': error_type,
}
def _error_type(error: Exception) -> str:
if isinstance(error.__cause__, Exception):
return type(error.__cause__).__name__
return type(error).__name__
def _set_span_container_attrs(span: Span, container: object) -> None:
    """Copy the session and chat ids from the container labels onto the span.

    Does nothing when the container has no ``labels`` dict; a label is
    copied only when it is a non-empty string.
    """
    labels = getattr(container, 'labels', None)
    if not isinstance(labels, dict):
        return
    # (label key, span attribute name), applied in this order.
    label_to_attribute = (
        ('session_id', 'session.id'),
        ('chat_id', 'chat.id'),
    )
    for label_key, attribute_name in label_to_attribute:
        label_value = labels.get(label_key)
        if isinstance(label_value, str) and label_value:
            span.set_attribute(attribute_name, label_value)