add sandbox runtime control endpoints

This commit is contained in:
Азамат Нураев 2026-04-28 21:53:26 +03:00
parent 0ca0bac9bf
commit 1b38bcfeab
17 changed files with 1408 additions and 119 deletions

View file

@ -1,3 +1,4 @@
from dataclasses import replace
from datetime import UTC, datetime, timedelta
from pathlib import Path
from types import TracebackType
@ -13,22 +14,51 @@ from adapter.config.model import SandboxConfig
from adapter.docker.runtime import DockerSandboxRuntime
from adapter.observability.noop import NoopMetrics, NoopTracer
from domain.error import SandboxError, SandboxStartError
from domain.sandbox import SandboxSession, SandboxStatus
from domain.sandbox import SandboxEndpoint, SandboxSession, SandboxStatus
from usecase.interface import Attrs, AttrValue
CHAT_ID = UUID('123e4567-e89b-12d3-a456-426614174000')
NON_CANONICAL_CHAT_ID = '123E4567E89B12D3A456426614174000'
SESSION_ID = UUID('00000000-0000-0000-0000-000000000010')
AGENT_ID = 'agent-alpha'
def _network_attrs(network_name: str = 'sandbox', ip: str = '172.20.0.8') -> dict[str, object]:
return {
'NetworkSettings': {
'Networks': {
network_name: {
'IPAddress': ip,
}
}
}
}
class FakeContainer:
def __init__(self, container_id: str) -> None:
def __init__(
self,
container_id: str,
*,
network_name: str = 'sandbox',
ip: str = '172.20.0.8',
) -> None:
self.id = container_id
self.stop_calls = 0
self.remove_calls: list[dict[str, bool]] = []
self.reload_calls = 0
self.attrs = _network_attrs(network_name, ip)
self.labels: dict[str, str] = {}
def stop(self) -> None:
self.stop_calls += 1
def reload(self) -> None:
self.reload_calls += 1
def remove(self, *, force: bool) -> None:
self.remove_calls.append({'force': force})
class FakeListedContainer(FakeContainer):
def __init__(
@ -37,10 +67,12 @@ class FakeListedContainer(FakeContainer):
*,
labels: dict[str, str],
created_at: str,
network_name: str = 'sandbox',
ip: str = '172.20.0.8',
) -> None:
super().__init__(container_id)
super().__init__(container_id, network_name=network_name, ip=ip)
self.labels = labels
self.attrs = {'Created': created_at}
self.attrs['Created'] = created_at
class FailingStopContainer(FakeListedContainer):
@ -66,8 +98,10 @@ class FailingStopContainer(FakeListedContainer):
class RunKwargs(TypedDict):
detach: bool
environment: dict[str, str]
labels: dict[str, str]
mounts: list[Mount]
network: str
class RunCall(TypedDict):
@ -90,16 +124,20 @@ class FakeContainers:
image: str,
*,
detach: bool,
environment: dict[str, str],
labels: dict[str, str],
mounts: list[Mount],
network: str,
) -> FakeContainer:
self.run_calls.append(
{
'args': (image,),
'kwargs': {
'detach': detach,
'environment': environment,
'labels': labels,
'mounts': mounts,
'network': network,
},
}
)
@ -266,6 +304,8 @@ def _find_record_call(
def build_config(tmp_path: Path) -> SandboxConfig:
return SandboxConfig(
image='sandbox:latest',
network_name='sandbox',
agent_service_port=8000,
ttl_seconds=300,
cleanup_interval_seconds=60,
chats_root=str(tmp_path / 'chats'),
@ -274,6 +314,7 @@ def build_config(tmp_path: Path) -> SandboxConfig:
chat_mount_path='/workspace/chat',
dependencies_mount_path='/workspace/dependencies',
lambda_tools_mount_path='/workspace/lambda-tools',
volume_mount_path='/workspace/volume',
)
@ -303,6 +344,8 @@ def test_runtime_create_applies_mount_policy_and_labels_with_canonical_chat_id(
session = runtime.create(
session_id=SESSION_ID,
chat_id=UUID(NON_CANONICAL_CHAT_ID),
agent_id=AGENT_ID,
volume_host_path=str(tmp_path / 'request-volume'),
created_at=created_at,
expires_at=expires_at,
)
@ -313,15 +356,25 @@ def test_runtime_create_applies_mount_policy_and_labels_with_canonical_chat_id(
assert session.status is SandboxStatus.RUNNING
assert session.created_at == created_at
assert session.expires_at == expires_at
assert session.agent_id == AGENT_ID
assert session.volume_host_path == str(
(tmp_path / 'request-volume').resolve(strict=False)
)
assert session.endpoint == SandboxEndpoint(ip='172.20.0.8', port=8000)
assert (tmp_path / 'chats' / str(CHAT_ID)).is_dir()
call = containers.run_calls[0]
assert call['args'] == ('sandbox:latest',)
assert call['kwargs']['detach'] is True
assert call['kwargs']['environment'] == {'AGENT_ID': AGENT_ID}
assert call['kwargs']['network'] == 'sandbox'
assert call['kwargs']['labels'] == {
'session_id': str(SESSION_ID),
'chat_id': str(CHAT_ID),
'expires_at': expires_at.isoformat(),
'agent_id': AGENT_ID,
'volume_host_path': str((tmp_path / 'request-volume').resolve(strict=False)),
'endpoint_port': '8000',
}
mounts = call['kwargs']['mounts']
@ -344,9 +397,103 @@ def test_runtime_create_applies_mount_policy_and_labels_with_canonical_chat_id(
'Type': 'bind',
'ReadOnly': True,
},
{
'Target': '/workspace/volume',
'Source': str((tmp_path / 'request-volume').resolve(strict=False)),
'Type': 'bind',
'ReadOnly': False,
},
]
def test_runtime_create_uses_configured_network_for_endpoint(tmp_path: Path) -> None:
config = replace(
build_config(tmp_path),
network_name='agent-net',
agent_service_port=9000,
)
(tmp_path / 'dependencies').mkdir()
(tmp_path / 'lambda-tools').mkdir()
containers = FakeContainers(
run_result=FakeContainer(
'container-456',
network_name='agent-net',
ip='10.42.0.7',
)
)
runtime = build_runtime(config, containers)
created_at = datetime(2026, 4, 2, 12, 0, tzinfo=UTC)
expires_at = created_at + timedelta(minutes=5)
session = runtime.create(
session_id=SESSION_ID,
chat_id=CHAT_ID,
agent_id=AGENT_ID,
volume_host_path=str(tmp_path / 'request-volume'),
created_at=created_at,
expires_at=expires_at,
)
assert containers.run_calls[0]['kwargs']['network'] == 'agent-net'
assert session.endpoint == SandboxEndpoint(ip='10.42.0.7', port=9000)
def test_runtime_create_removes_container_when_endpoint_extraction_fails(
tmp_path: Path,
) -> None:
config = build_config(tmp_path)
(tmp_path / 'dependencies').mkdir()
(tmp_path / 'lambda-tools').mkdir()
created_container = FakeContainer(
'container-789',
network_name='unexpected-net',
)
containers = FakeContainers(run_result=created_container)
runtime = build_runtime(config, containers)
with pytest.raises(SandboxStartError) as excinfo:
runtime.create(
session_id=SESSION_ID,
chat_id=CHAT_ID,
agent_id=AGENT_ID,
volume_host_path=str(tmp_path / 'request-volume'),
created_at=datetime(2026, 4, 2, 12, 0, tzinfo=UTC),
expires_at=datetime(2026, 4, 2, 12, 5, tzinfo=UTC),
)
assert str(excinfo.value) == 'sandbox_start_failed'
assert containers.run_calls
assert created_container.remove_calls == [{'force': True}]
def test_runtime_create_applies_request_volume_bind_as_rw(tmp_path: Path) -> None:
config = build_config(tmp_path)
(tmp_path / 'dependencies').mkdir()
(tmp_path / 'lambda-tools').mkdir()
containers = FakeContainers()
runtime = build_runtime(config, containers)
created_at = datetime(2026, 4, 2, 12, 0, tzinfo=UTC)
expires_at = created_at + timedelta(minutes=5)
volume_host_path = str(tmp_path / 'request-volume')
runtime.create(
session_id=SESSION_ID,
chat_id=CHAT_ID,
agent_id=AGENT_ID,
volume_host_path=volume_host_path,
created_at=created_at,
expires_at=expires_at,
)
mounts = [dict(mount) for mount in containers.run_calls[0]['kwargs']['mounts']]
assert {
'Target': '/workspace/volume',
'Source': str((tmp_path / 'request-volume').resolve(strict=False)),
'Type': 'bind',
'ReadOnly': False,
} in mounts
def test_runtime_create_records_observability(tmp_path: Path) -> None:
config = build_config(tmp_path)
(tmp_path / 'dependencies').mkdir()
@ -366,6 +513,8 @@ def test_runtime_create_records_observability(tmp_path: Path) -> None:
session = runtime.create(
session_id=SESSION_ID,
chat_id=CHAT_ID,
agent_id=AGENT_ID,
volume_host_path=str(tmp_path / 'request-volume'),
created_at=created_at,
expires_at=expires_at,
)
@ -402,6 +551,8 @@ def test_runtime_create_raises_start_error_when_container_id_is_missing(
runtime.create(
session_id=SESSION_ID,
chat_id=CHAT_ID,
agent_id=AGENT_ID,
volume_host_path=str(tmp_path / 'request-volume'),
created_at=datetime(2026, 4, 2, 12, 0, tzinfo=UTC),
expires_at=datetime(2026, 4, 2, 12, 5, tzinfo=UTC),
)
@ -430,6 +581,8 @@ def test_runtime_create_error_records_observability_when_container_id_missing(
runtime.create(
session_id=SESSION_ID,
chat_id=CHAT_ID,
agent_id=AGENT_ID,
volume_host_path=str(tmp_path / 'request-volume'),
created_at=datetime(2026, 4, 2, 12, 0, tzinfo=UTC),
expires_at=datetime(2026, 4, 2, 12, 5, tzinfo=UTC),
)
@ -438,7 +591,7 @@ def test_runtime_create_error_records_observability_when_container_id_missing(
_find_increment_call(
metrics,
'sandbox.runtime.error.total',
attrs={'operation': 'create', 'error.type': 'SandboxStartError'},
attrs={'operation': 'create', 'error.type': 'ValueError'},
)
duration_call = _find_record_call(
metrics,
@ -598,6 +751,38 @@ def test_runtime_stop_records_observability_on_success(tmp_path: Path) -> None:
assert stop_error_calls == []
def test_runtime_delete_removes_container_with_force(tmp_path: Path) -> None:
config = build_config(tmp_path)
containers = FakeContainers()
container = FakeListedContainer(
'container-123',
labels={
'session_id': str(SESSION_ID),
'chat_id': str(CHAT_ID),
'expires_at': '2026-04-02T12:05:00+00:00',
},
created_at='2026-04-02T12:00:00Z',
)
containers.get_result = container
runtime = build_runtime(config, containers)
runtime.delete('container-123')
assert containers.get_calls == ['container-123']
assert container.remove_calls == [{'force': True}]
def test_runtime_delete_ignores_missing_container(tmp_path: Path) -> None:
config = build_config(tmp_path)
containers = FakeContainers()
containers.get_result = NotFound('missing')
runtime = build_runtime(config, containers)
runtime.delete('container-123')
assert containers.get_calls == ['container-123']
def test_runtime_list_active_sessions_reads_valid_labeled_containers(
tmp_path: Path,
) -> None:
@ -611,6 +796,9 @@ def test_runtime_list_active_sessions_reads_valid_labeled_containers(
'session_id': str(SESSION_ID),
'chat_id': str(CHAT_ID),
'expires_at': expires_at.isoformat(),
'agent_id': AGENT_ID,
'volume_host_path': str(tmp_path / 'request-volume'),
'endpoint_port': '8000',
},
created_at='2026-04-02T12:00:00Z',
),
@ -635,10 +823,24 @@ def test_runtime_list_active_sessions_reads_valid_labeled_containers(
status=SandboxStatus.RUNNING,
created_at=datetime(2026, 4, 2, 12, 0, tzinfo=UTC),
expires_at=expires_at,
agent_id=AGENT_ID,
volume_host_path=str(tmp_path / 'request-volume'),
endpoint=SandboxEndpoint(ip='172.20.0.8', port=8000),
)
]
assert containers.list_calls == [
{'filters': {'label': ['session_id', 'chat_id', 'expires_at']}}
{
'filters': {
'label': [
'session_id',
'chat_id',
'expires_at',
'agent_id',
'volume_host_path',
'endpoint_port',
]
}
}
]
@ -653,6 +855,9 @@ def test_runtime_list_active_records_observability(tmp_path: Path) -> None:
'session_id': str(SESSION_ID),
'chat_id': str(CHAT_ID),
'expires_at': expires_at.isoformat(),
'agent_id': AGENT_ID,
'volume_host_path': str(tmp_path / 'request-volume'),
'endpoint_port': '8000',
},
created_at='2026-04-02T12:00:00Z',
),