import asyncio
import logging
import time
from typing import Any

from api.clients.browser_rpc_contracts import BrowserRpcError, BrowserRpcRunner
from api.domain.task_status import TaskStatus
from api.repositories.task_store import TaskRecord, TaskStore
from api.services.browser_runtime_manager import cleanup_browser_runtime, ensure_browser_runtime

logger = logging.getLogger(__name__)


class TaskService:
    """Run browser tasks in background asyncio tasks and track them in a store.

    Each run is persisted through the injected ``TaskStore``, executed by a
    background worker that calls the injected ``BrowserRpcRunner``, and
    reported through events published on the store. A semaphore limits how
    many workers may execute the RPC phase at the same time.
    """

    def __init__(
        self,
        store: TaskStore,
        rpc_client: BrowserRpcRunner,
        max_concurrency: int,
        rpc_timeout_cap: float | None = None,
    ) -> None:
        """Wire up the service.

        Args:
            store: Persistence and pub/sub backend for task records.
            rpc_client: Client whose ``run`` coroutine executes a task.
            max_concurrency: Number of workers allowed inside the semaphore.
            rpc_timeout_cap: Optional upper bound applied to the
                ``timeout_sec`` value handed to the RPC client.
        """
        self._store = store
        self._rpc_client = rpc_client
        self._semaphore = asyncio.Semaphore(max_concurrency)
        self._rpc_timeout_cap = rpc_timeout_cap
        # Strong references keep fire-and-forget tasks from being garbage
        # collected while they are still running.
        self._background_tasks: set[asyncio.Task[None]] = set()
        # Lets cancel_run() find the worker that belongs to a run id.
        self._task_by_run_id: dict[str, asyncio.Task[None]] = {}

    async def submit_task(self, task: str, timeout: int, metadata: dict | None) -> TaskRecord:
        """Create a run on the ``"default"`` thread and return its record."""
        return await self.create_run(
            thread_id="default",
            user_input=task,
            timeout=timeout,
            metadata=metadata,
        )

    async def create_run(
        self,
        thread_id: str,
        user_input: str,
        timeout: int,
        metadata: dict | None,
    ) -> TaskRecord:
        """Persist a new run and start its background worker.

        Returns:
            The record created by the store; the worker is scheduled but has
            not necessarily started when this returns.
        """
        record = await self._store.create(
            task=user_input,
            timeout=timeout,
            metadata=metadata,
            thread_id=thread_id,
        )
        run_id = record.task_id

        worker = asyncio.create_task(self._worker(run_id))
        self._background_tasks.add(worker)
        worker.add_done_callback(self._background_tasks.discard)
        self._task_by_run_id[run_id] = worker

        def _forget(_: asyncio.Task[None]) -> None:
            # Drop the run-id mapping once the worker has finished.
            self._task_by_run_id.pop(run_id, None)

        worker.add_done_callback(_forget)
        return record

    async def get_task(self, task_id: str) -> TaskRecord | None:
        """Return the stored record for ``task_id`` (whatever the store yields)."""
        return await self._store.get(task_id)

    async def get_run(self, run_id: str) -> TaskRecord | None:
        """Alias of :meth:`get_task` using run terminology."""
        return await self.get_task(run_id)

    async def list_thread_runs(self, thread_id: str) -> list[TaskRecord]:
        """Return the store's list of runs for ``thread_id``."""
        return await self._store.list_by_thread(thread_id)

    async def cancel_run(self, run_id: str) -> TaskRecord | None:
        """Request cancellation of a run.

        Returns:
            ``None`` when the store reports no such run; otherwise the record
            returned by ``set_cancel_requested``.
        """
        rec = await self._store.set_cancel_requested(run_id)
        if rec is None:
            return None

        if rec.status == TaskStatus.cancelled:
            # The store already reports the run as cancelled, so announce it
            # here instead of cancelling the worker task.
            await self._store.publish(
                run_id,
                self._event(run_id, "cancelled", {"status": rec.status.value}),
            )
            return rec

        worker = self._task_by_run_id.get(run_id)
        if worker is not None and not worker.done():
            worker.cancel()
        return rec

    async def delete_run(self, run_id: str) -> tuple[bool, bool]:
        """Delegate to the store's ``delete_if_finished`` and return its result."""
        return await self._store.delete_if_finished(run_id)

    async def wait_run(self, run_id: str, timeout: float | None = None) -> TaskRecord | None:
        """Wait until a run leaves the queued/running states or ``timeout`` elapses.

        Args:
            run_id: Run to wait for.
            timeout: Maximum seconds to wait; ``None`` waits without limit.

        Returns:
            ``None`` for an unknown run, otherwise the record as currently
            stored (re-read after waiting, also when the wait timed out).
        """
        rec = await self._store.get(run_id)
        if rec is None:
            return None
        if rec.status not in (TaskStatus.queued, TaskStatus.running):
            return rec

        try:
            # wait_for with timeout=None simply awaits the event.
            await asyncio.wait_for(rec.done_event.wait(), timeout=timeout)
        except asyncio.TimeoutError:
            pass  # Fall through and return whatever state the run is in now.
        return await self._store.get(run_id)

    async def subscribe_run_stream(self, run_id: str):
        """Return the store's subscription object for ``run_id``."""
        return await self._store.subscribe(run_id)

    async def unsubscribe_run_stream(self, run_id: str, queue) -> None:
        """Remove ``queue`` from the store's subscribers for ``run_id``."""
        await self._store.unsubscribe(run_id, queue)

    async def close(self) -> None:
        """Cancel every outstanding worker and wait for all of them to finish."""
        # Snapshot first: done-callbacks mutate the set while tasks finish.
        pending = list(self._background_tasks)
        if not pending:
            return
        for worker in pending:
            worker.cancel()
        await asyncio.gather(*pending, return_exceptions=True)
        self._background_tasks.clear()
        self._task_by_run_id.clear()

    async def _worker(self, task_id: str) -> None:
        """Drive one run: mark it running, wait for a slot, then execute it.

        Cancellation that arrives while the ``started`` event is being
        published or while the worker is waiting for a semaphore slot is
        recorded in the store (status + ``cancelled`` event) before the
        ``CancelledError`` is re-raised, so the run does not stay in the
        running state.
        """
        rec = await self._store.set_running(task_id)
        if rec is None or rec.status == TaskStatus.cancelled:
            return

        try:
            await self._store.publish(
                task_id,
                self._event(task_id, "started", {"status": TaskStatus.running.value}),
            )
            await self._semaphore.acquire()
        except asyncio.CancelledError:
            # No slot is held here, so there is nothing to release.
            await self._mark_cancelled(task_id)
            raise

        try:
            await self._run_with_runtime(task_id, rec)
        finally:
            self._semaphore.release()

    async def _run_with_runtime(self, task_id: str, rec: TaskRecord) -> None:
        """Execute the RPC for one run while holding a semaphore slot.

        Records the outcome in the store, publishes the matching events and
        always attempts ``cleanup_browser_runtime`` afterwards.
        """
        runtime: dict[str, str] | None = None
        try:
            if rec.cancel_requested:
                await self._mark_cancelled(task_id)
                return

            runtime = await asyncio.to_thread(
                ensure_browser_runtime,
                task_id=task_id,
                metadata=rec.metadata,
                thread_id=rec.thread_id,
            )

            rpc_timeout = float(rec.timeout)
            if self._rpc_timeout_cap is not None:
                rpc_timeout = min(rpc_timeout, self._rpc_timeout_cap)

            # NOTE(review): the outer guard uses the uncapped record timeout
            # (+5s grace) rather than the capped rpc_timeout - confirm intent.
            raw = await asyncio.wait_for(
                self._rpc_client.run(
                    task=rec.task,
                    timeout_sec=rpc_timeout,
                    rpc_url=runtime.get("rpc_url"),
                ),
                timeout=float(rec.timeout) + 5,
            )
            if not isinstance(raw, dict):
                # Handled by the generic handler below as an internal error,
                # with a message that names the actual problem.
                raise TypeError(f"unexpected RPC response type {type(raw).__name__}")

            raw = self._with_runtime_metadata(raw, runtime)
            success = bool(raw.get("success"))
            await self._store.set_done(
                task_id=task_id,
                success=success,
                raw_response=raw,
                error=None,
                result=raw.get("result"),
                history=self._extract_history(raw),
            )

            done = await self._store.get(task_id)
            if done is not None:
                await self._publish_history_events(done)
                await self._store.publish(
                    task_id,
                    self._event(
                        task_id,
                        "completed" if success else "failed",
                        {
                            "status": done.status.value,
                            "output": done.result,
                            "error": done.error,
                        },
                    ),
                )
        except asyncio.CancelledError:
            await self._mark_cancelled(task_id)
            raise
        except asyncio.TimeoutError:
            await self._fail(task_id, "Timeout exceeded")
        except BrowserRpcError as exc:
            await self._fail(task_id, str(exc))
        except Exception as exc:
            await self._fail(task_id, f"Internal error: {exc}")
        finally:
            try:
                await asyncio.to_thread(
                    cleanup_browser_runtime,
                    task_id=task_id,
                    metadata=rec.metadata,
                    thread_id=rec.thread_id,
                )
            except Exception:
                # Cleanup stays best-effort, but leave a trace for operators.
                logger.warning(
                    "Browser runtime cleanup failed for task %s",
                    task_id,
                    exc_info=True,
                )

    async def _mark_cancelled(self, task_id: str) -> None:
        """Set the run to cancelled in the store and publish a ``cancelled`` event."""
        await self._store.set_cancelled(task_id)
        await self._store.publish(
            task_id,
            self._event(task_id, "cancelled", {"status": TaskStatus.cancelled.value}),
        )

    async def _fail(self, task_id: str, error: str) -> None:
        """Store a failed outcome with ``error`` and publish a ``failed`` event."""
        await self._store.set_done(
            task_id=task_id,
            success=False,
            raw_response=None,
            error=error,
            history=None,
        )
        failed = await self._store.get(task_id)
        if failed is not None:
            await self._store.publish(
                task_id,
                self._event(
                    task_id,
                    "failed",
                    {
                        "status": failed.status.value,
                        "error": failed.error,
                    },
                ),
            )

    async def _publish_history_events(self, rec: TaskRecord) -> None:
        """Publish one ``output`` event per history item of ``rec``.

        A missing ``step`` falls back to the 1-based position, a missing
        ``kind`` falls back to ``type`` and then ``"system"``, and ``data``
        is replaced by an empty dict unless it is a dict.
        """
        # Treat a history of None as empty rather than failing a finished run.
        for index, item in enumerate(rec.history or [], start=1):
            data = item.get("data")
            await self._store.publish(
                rec.task_id,
                self._event(
                    rec.task_id,
                    "output",
                    {
                        "step": item.get("step", index),
                        "kind": item.get("kind") or item.get("type") or "system",
                        "content": item.get("content"),
                        "data": data if isinstance(data, dict) else {},
                    },
                ),
            )

    @staticmethod
    def _event(run_id: str, event: str, data: dict[str, Any]) -> dict[str, Any]:
        """Build an event envelope stamped with the current ``time.time()``."""
        return {
            "run_id": run_id,
            "event": event,
            "ts": time.time(),
            "data": data,
        }

    @staticmethod
    def _extract_history(raw: dict | None) -> list[dict]:
        """Return the dict entries of ``raw["history"]``, or ``[]`` if unusable."""
        if not isinstance(raw, dict):
            return []
        events = raw.get("history")
        if not isinstance(events, list):
            return []
        return [event for event in events if isinstance(event, dict)]

    @staticmethod
    def _with_runtime_metadata(raw: dict[str, Any], runtime: dict[str, str] | None) -> dict[str, Any]:
        """Return a copy of ``raw`` enriched with fields taken from ``runtime``.

        ``raw`` is returned untouched when it is not a dict or ``runtime`` is
        empty. ``browser_view`` is only filled in when the response has no
        truthy value of its own; ``isolation_mode`` is always set (default
        ``"shared"``); ``owner_hash`` is copied when truthy.
        """
        if not isinstance(raw, dict) or not runtime:
            return raw

        enriched = dict(raw)
        browser_view = runtime.get("browser_view")
        if browser_view and not enriched.get("browser_view"):
            enriched["browser_view"] = browser_view
        enriched["isolation_mode"] = runtime.get("isolation_mode", "shared")
        owner_hash = runtime.get("owner_hash")
        if owner_hash:
            enriched["owner_hash"] = owner_hash
        return enriched