# Source: BrowserUse_and_ComputerUse_.../api/routes/runs.py
# (repository viewer header: 130 lines, 4.5 KiB, Python)

import asyncio
import json
from typing import AsyncIterator
from fastapi import APIRouter, Depends, HTTPException, Query, Response
from fastapi.responses import JSONResponse, StreamingResponse
from api.contracts.task_schemas import (
RunCreateRequest,
RunListResponse,
RunResponse,
RunStreamEvent,
RunSummaryResponse,
RunWaitResponse,
)
from api.mappers.task_record_mapper import TaskRecordMapper
from api.routes.dependencies import get_task_service
from api.services.protocols import TaskServiceProtocol
# Router for the run endpoints defined below. The "runs" tag is passed to
# FastAPI, which uses tags to group these operations in the OpenAPI schema.
router = APIRouter(tags=["runs"])
@router.get("/threads/{thread_id}/runs", response_model=RunListResponse)
async def list_thread_runs(
    thread_id: str,
    service: TaskServiceProtocol = Depends(get_task_service),
) -> RunListResponse:
    """List the runs the task service reports for ``thread_id``.

    The records returned by ``service.list_thread_runs`` are converted to the
    public ``RunListResponse`` shape by ``TaskRecordMapper``.
    """
    return TaskRecordMapper.to_thread_run_list(
        thread_id,
        await service.list_thread_runs(thread_id),
    )
@router.post("/runs", response_model=RunSummaryResponse, status_code=202)
async def create_run(
    payload: RunCreateRequest,
    service: TaskServiceProtocol = Depends(get_task_service),
) -> RunSummaryResponse:
    """Submit a new run and answer ``202 Accepted`` with its summary.

    Leading and trailing whitespace is stripped from ``thread_id`` and
    ``input`` before they are handed to the service; ``timeout`` and
    ``metadata`` are forwarded unchanged.
    """
    thread_id = payload.thread_id.strip()
    user_input = payload.input.strip()
    record = await service.create_run(
        thread_id=thread_id,
        user_input=user_input,
        timeout=payload.timeout,
        metadata=payload.metadata,
    )
    return TaskRecordMapper.to_run_summary(record)
@router.get("/runs/{run_id}", response_model=RunResponse)
async def get_run(
    run_id: str,
    service: TaskServiceProtocol = Depends(get_task_service),
) -> RunResponse:
    """Return the full representation of a single run.

    Raises:
        HTTPException: 404 when the service has no record for ``run_id``.
    """
    record = await service.get_run(run_id)
    if record is not None:
        return TaskRecordMapper.to_run_response(record)
    raise HTTPException(status_code=404, detail="Run not found")
@router.post("/runs/{run_id}/cancel", response_model=RunSummaryResponse)
async def cancel_run(
    run_id: str,
    service: TaskServiceProtocol = Depends(get_task_service),
) -> RunSummaryResponse:
    """Ask the task service to cancel ``run_id`` and return the run summary.

    Raises:
        HTTPException: 404 when ``service.cancel_run`` returns ``None``.
    """
    record = await service.cancel_run(run_id)
    if record is not None:
        return TaskRecordMapper.to_run_summary(record)
    raise HTTPException(status_code=404, detail="Run not found")
@router.delete("/runs/{run_id}", status_code=204)
async def delete_run(
    run_id: str,
    service: TaskServiceProtocol = Depends(get_task_service),
) -> Response:
    """Delete a run that is no longer active.

    ``service.delete_run`` reports two flags: whether the run exists and
    whether it was actually deleted. A run that exists but was not deleted
    is reported to the client as still active.

    Returns:
        An empty ``204 No Content`` response on success.

    Raises:
        HTTPException: 404 if the run does not exist; 409 if it exists but
            was not deleted.
    """
    found, removed = await service.delete_run(run_id)
    if found and removed:
        return Response(status_code=204)
    if not found:
        raise HTTPException(status_code=404, detail="Run not found")
    raise HTTPException(status_code=409, detail="Run is still active. Cancel it first.")
@router.get("/runs/{run_id}/wait", response_model=RunWaitResponse)
async def wait_run(
    run_id: str,
    timeout: float | None = Query(default=None, ge=0),
    service: TaskServiceProtocol = Depends(get_task_service),
) -> JSONResponse | RunWaitResponse:
    """Wait on ``service.wait_run`` and report the run's state.

    ``timeout`` is an optional, non-negative query parameter forwarded to the
    service. The body is always a ``RunWaitResponse``; its HTTP status tells
    the client whether the run is done:

    * 202 if the returned record still has an active status,
    * the route's default (200) otherwise.

    Raises:
        HTTPException: 404 when ``service.wait_run`` returns ``None``.
    """
    record = await service.wait_run(run_id, timeout=timeout)
    if record is None:
        raise HTTPException(status_code=404, detail="Run not found")

    still_running = TaskRecordMapper.is_active_status(record.status)
    body = TaskRecordMapper.to_run_wait(record)
    if still_running:
        # A plain return would use the route's default status, so build the
        # response explicitly to signal "not finished yet" with 202.
        return JSONResponse(status_code=202, content=body.model_dump(mode="json"))
    return body
def _sse_data_frame(payload: dict) -> str:
    """Encode ``payload`` as one Server-Sent Events ``data:`` frame.

    ``json.dumps`` escapes newlines inside string values, so the JSON always
    fits on a single ``data:`` line. The frame must end with two real newline
    characters (i.e. an empty line): that blank line is what makes an SSE
    client dispatch the event.
    """
    return f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"


@router.get("/runs/{run_id}/stream")
async def stream_run(
    run_id: str,
    service: TaskServiceProtocol = Depends(get_task_service),
) -> StreamingResponse:
    """Stream a run's events to the client as Server-Sent Events.

    Each item taken from the run's subscription queue is validated through
    ``RunStreamEvent`` and sent as a JSON ``data:`` frame. The stream ends
    after a ``completed``, ``failed`` or ``cancelled`` event. If no item
    arrives within the keep-alive interval, the run is re-fetched: the stream
    ends when the run no longer exists or is no longer active; otherwise an
    SSE comment line (``: keep-alive``) is sent and waiting resumes.

    Raises:
        HTTPException: 404 when ``subscribe_run_stream`` returns ``None``.
    """
    queue = await service.subscribe_run_stream(run_id)
    if queue is None:
        raise HTTPException(status_code=404, detail="Run not found")
    # Re-bind so the nested generator captures a name whose type is already
    # narrowed to non-None (type checkers do not narrow ``queue`` in closures).
    stream_queue = queue
    keepalive_seconds = 15
    terminal_events = ("completed", "failed", "cancelled")

    async def event_stream() -> AsyncIterator[str]:
        try:
            while True:
                try:
                    item = await asyncio.wait_for(
                        stream_queue.get(), timeout=keepalive_seconds
                    )
                except asyncio.TimeoutError:
                    # Idle interval: stop if the run is gone or no longer
                    # active, otherwise send a comment line and keep waiting.
                    rec = await service.get_run(run_id)
                    if rec is None or not TaskRecordMapper.is_active_status(rec.status):
                        break
                    yield ": keep-alive\n\n"
                    continue
                payload = RunStreamEvent.model_validate(item).model_dump(mode="json")
                yield _sse_data_frame(payload)
                if payload["event"] in terminal_events:
                    break
        finally:
            # Runs however the generator ends (normal exit, exception, or
            # being closed early), so the queue subscription is always released.
            await service.unsubscribe_run_stream(run_id, stream_queue)

    return StreamingResponse(event_stream(), media_type="text/event-stream")