130 lines
4.5 KiB
Python
130 lines
4.5 KiB
Python
import asyncio
|
|
import json
|
|
from typing import AsyncIterator
|
|
|
|
from fastapi import APIRouter, Depends, HTTPException, Query, Response
|
|
from fastapi.responses import JSONResponse, StreamingResponse
|
|
|
|
from api.contracts.task_schemas import (
|
|
RunCreateRequest,
|
|
RunListResponse,
|
|
RunResponse,
|
|
RunStreamEvent,
|
|
RunSummaryResponse,
|
|
RunWaitResponse,
|
|
)
|
|
from api.mappers.task_record_mapper import TaskRecordMapper
|
|
from api.routes.dependencies import get_task_service
|
|
from api.services.protocols import TaskServiceProtocol
|
|
|
|
# Router for the run endpoints below; every route it registers is
# grouped under the "runs" tag.
router = APIRouter(
    tags=["runs"],
)
|
|
|
|
|
|
@router.get("/threads/{thread_id}/runs", response_model=RunListResponse)
async def list_thread_runs(
    thread_id: str,
    service: TaskServiceProtocol = Depends(get_task_service),
) -> RunListResponse:
    """List the runs the task service reports for a thread.

    Args:
        thread_id: Thread identifier taken from the URL path.
        service: Task service injected through ``get_task_service``.

    Returns:
        The service's run records for the thread, shaped by
        ``TaskRecordMapper.to_thread_run_list``. No 404 is raised here for an
        unknown thread; whatever the service returns is mapped as-is.
    """
    return TaskRecordMapper.to_thread_run_list(
        thread_id,
        await service.list_thread_runs(thread_id),
    )
|
|
|
|
|
|
@router.post("/runs", response_model=RunSummaryResponse, status_code=202)
async def create_run(
    payload: RunCreateRequest,
    service: TaskServiceProtocol = Depends(get_task_service),
) -> RunSummaryResponse:
    """Create a run and acknowledge it with ``202 Accepted``.

    Args:
        payload: Request body carrying the thread id, user input, optional
            timeout and metadata.
        service: Task service injected through ``get_task_service``.

    Returns:
        A summary of the newly created run record.
    """
    # Surrounding whitespace is trimmed from the thread id and the input;
    # timeout and metadata are forwarded untouched.
    cleaned_thread_id = payload.thread_id.strip()
    cleaned_input = payload.input.strip()

    record = await service.create_run(
        thread_id=cleaned_thread_id,
        user_input=cleaned_input,
        timeout=payload.timeout,
        metadata=payload.metadata,
    )
    return TaskRecordMapper.to_run_summary(record)
|
|
|
|
|
|
@router.get("/runs/{run_id}", response_model=RunResponse)
async def get_run(
    run_id: str,
    service: TaskServiceProtocol = Depends(get_task_service),
) -> RunResponse:
    """Return the full representation of a single run.

    Args:
        run_id: Run identifier taken from the URL path.
        service: Task service injected through ``get_task_service``.

    Returns:
        The run record mapped through ``TaskRecordMapper.to_run_response``.

    Raises:
        HTTPException: 404 when the service has no record for ``run_id``.
    """
    record = await service.get_run(run_id)
    if record is not None:
        return TaskRecordMapper.to_run_response(record)
    raise HTTPException(status_code=404, detail="Run not found")
|
|
|
|
|
|
@router.post("/runs/{run_id}/cancel", response_model=RunSummaryResponse)
async def cancel_run(
    run_id: str,
    service: TaskServiceProtocol = Depends(get_task_service),
) -> RunSummaryResponse:
    """Ask the task service to cancel a run and return its summary.

    Args:
        run_id: Run identifier taken from the URL path.
        service: Task service injected through ``get_task_service``.

    Returns:
        The summary of the record the service returned from ``cancel_run``.

    Raises:
        HTTPException: 404 when the service returns no record for ``run_id``.
    """
    cancelled = await service.cancel_run(run_id)
    if cancelled is not None:
        return TaskRecordMapper.to_run_summary(cancelled)
    raise HTTPException(status_code=404, detail="Run not found")
|
|
|
|
|
|
@router.delete("/runs/{run_id}", status_code=204)
async def delete_run(
    run_id: str,
    service: TaskServiceProtocol = Depends(get_task_service),
) -> Response:
    """Delete a run, answering ``204 No Content`` on success.

    Args:
        run_id: Run identifier taken from the URL path.
        service: Task service injected through ``get_task_service``; its
            ``delete_run`` returns an ``(exists, deleted)`` pair.

    Returns:
        An empty 204 response.

    Raises:
        HTTPException: 404 when the run does not exist; 409 when it exists
            but the service did not delete it.
    """
    run_found, run_removed = await service.delete_run(run_id)
    if not run_found:
        raise HTTPException(status_code=404, detail="Run not found")
    if run_removed:
        return Response(status_code=204)
    # Found but not removed: reported to the client as a still-active run.
    raise HTTPException(status_code=409, detail="Run is still active. Cancel it first.")
|
|
|
|
|
|
@router.get("/runs/{run_id}/wait", response_model=RunWaitResponse)
async def wait_run(
    run_id: str,
    timeout: float | None = Query(default=None, ge=0),
    service: TaskServiceProtocol = Depends(get_task_service),
) -> JSONResponse | RunWaitResponse:
    """Wait on a run via the task service and report its state.

    Args:
        run_id: Run identifier taken from the URL path.
        timeout: Optional non-negative query parameter forwarded unchanged to
            ``service.wait_run``; ``None`` when omitted.
        service: Task service injected through ``get_task_service``.

    Returns:
        The wait snapshot with status 200 once the run is no longer active,
        or the same snapshot wrapped in a 202 ``JSONResponse`` while the run
        is still active.

    Raises:
        HTTPException: 404 when the service has no record for ``run_id``.
    """
    record = await service.wait_run(run_id, timeout=timeout)
    if record is None:
        raise HTTPException(status_code=404, detail="Run not found")

    still_running = TaskRecordMapper.is_active_status(record.status)
    snapshot = TaskRecordMapper.to_run_wait(record)
    if not still_running:
        return snapshot
    # The wait ended before the run did: 202 signals "not finished yet".
    return JSONResponse(status_code=202, content=snapshot.model_dump(mode="json"))
|
|
|
|
|
|
@router.get("/runs/{run_id}/stream")
async def stream_run(
    run_id: str,
    service: TaskServiceProtocol = Depends(get_task_service),
) -> StreamingResponse:
    """Stream a run's events to the client as Server-Sent Events.

    Each queued item is validated as a ``RunStreamEvent`` and sent as one
    ``data: <json>`` event. When nothing arrives for 15 seconds, the run is
    re-read from the service. The stream ends if the run is gone or no longer
    active; otherwise an SSE comment is sent as a keep-alive. The stream also
    ends after a ``completed``, ``failed`` or ``cancelled`` event. The
    subscription is always released when the generator finishes.

    Args:
        run_id: Run identifier taken from the URL path.
        service: Task service injected through ``get_task_service``.

    Returns:
        A ``text/event-stream`` streaming response.

    Raises:
        HTTPException: 404 when the service cannot subscribe to ``run_id``.
    """
    queue = await service.subscribe_run_stream(run_id)
    if queue is None:
        raise HTTPException(status_code=404, detail="Run not found")
    # Re-bind so the nested generator closes over a name known to be
    # non-None (type checkers do not carry narrowing into closures).
    stream_queue = queue
    # Event names after which the stream is closed.
    terminal_events = ("completed", "failed", "cancelled")
    # Seconds of silence before the run is re-checked / a keep-alive is sent.
    idle_timeout = 15

    async def event_stream() -> AsyncIterator[str]:
        try:
            while True:
                try:
                    item = await asyncio.wait_for(stream_queue.get(), timeout=idle_timeout)
                except asyncio.TimeoutError:
                    rec = await service.get_run(run_id)
                    if rec is None or not TaskRecordMapper.is_active_status(rec.status):
                        break
                    # Lines starting with ':' are SSE comments; clients ignore
                    # them, but they keep idle connections from being dropped.
                    yield ": keep-alive\n\n"
                    continue

                payload = RunStreamEvent.model_validate(item).model_dump(mode="json")
                # An SSE event is terminated by a blank line made of real
                # newline characters; an escaped backslash-n would leave the
                # event unterminated and the client would never dispatch it.
                yield f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"

                if payload["event"] in terminal_events:
                    break
        finally:
            # Runs on normal completion and when the client disconnects and
            # the generator is closed/cancelled.
            await service.unsubscribe_run_stream(run_id, stream_queue)

    return StreamingResponse(event_stream(), media_type="text/event-stream")
|