BrowserUse_and_ComputerUse_.../api/services/task_service.py

260 lines
10 KiB
Python

import asyncio
import logging
import time
from typing import Any

from api.clients.browser_rpc_contracts import BrowserRpcError, BrowserRpcRunner
from api.domain.task_status import TaskStatus
from api.repositories.task_store import TaskRecord, TaskStore
from api.services.browser_runtime_manager import cleanup_browser_runtime, ensure_browser_runtime
class TaskService:
    """Execute browser-automation runs in the background and track their lifecycle.

    Each run is persisted through a ``TaskStore`` and executed by its own
    ``asyncio`` task (``_worker``). A semaphore limits how many runs execute
    their browser work at the same time. Lifecycle changes are published to
    the store as event dicts built by ``_event`` (``started``, ``output``,
    ``completed``, ``failed``, ``cancelled``).
    """

    # Logger used for best-effort failures that must not abort a run.
    _log = logging.getLogger(__name__)

    def __init__(
        self,
        store: TaskStore,
        rpc_client: BrowserRpcRunner,
        max_concurrency: int,
        rpc_timeout_cap: float | None = None,
    ) -> None:
        """Create the service.

        Args:
            store: Backend that persists run records and distributes run events.
            rpc_client: Client used to execute a task against a browser runtime.
            max_concurrency: Maximum number of runs doing browser work at once.
            rpc_timeout_cap: Optional upper bound (seconds) on the ``timeout_sec``
                passed to the RPC client; when ``None`` the run's own timeout is used.
        """
        self._store = store
        self._rpc_client = rpc_client
        self._semaphore = asyncio.Semaphore(max_concurrency)
        self._rpc_timeout_cap = rpc_timeout_cap
        # Strong references to live workers: the event loop only keeps weak
        # references to tasks, so unreferenced tasks could disappear mid-run.
        self._background_tasks: set[asyncio.Task[None]] = set()
        # Lets cancel_run() find the worker task belonging to a run id.
        self._task_by_run_id: dict[str, asyncio.Task[None]] = {}

    async def submit_task(self, task: str, timeout: int, metadata: dict | None) -> TaskRecord:
        """Create a run on the ``"default"`` thread (convenience wrapper around ``create_run``)."""
        return await self.create_run(thread_id="default", user_input=task, timeout=timeout, metadata=metadata)

    async def create_run(self, thread_id: str, user_input: str, timeout: int, metadata: dict | None) -> TaskRecord:
        """Persist a new run and start executing it in the background.

        Returns the freshly created record immediately; progress is observable
        through ``get_run`` / ``wait_run`` or the run's event stream.
        """
        record = await self._store.create(task=user_input, timeout=timeout, metadata=metadata, thread_id=thread_id)
        run_id = record.task_id
        worker = asyncio.create_task(self._worker(run_id))
        self._background_tasks.add(worker)
        self._task_by_run_id[run_id] = worker

        def _forget(finished: asyncio.Task[None]) -> None:
            # Drop bookkeeping once the worker is done, however it ended.
            self._background_tasks.discard(finished)
            self._task_by_run_id.pop(run_id, None)

        worker.add_done_callback(_forget)
        return record

    async def get_task(self, task_id: str) -> TaskRecord | None:
        """Return the stored record for ``task_id``, or ``None`` if the store has none."""
        return await self._store.get(task_id)

    async def get_run(self, run_id: str) -> TaskRecord | None:
        """Alias of ``get_task`` using run terminology."""
        return await self.get_task(run_id)

    async def list_thread_runs(self, thread_id: str) -> list[TaskRecord]:
        """Return all runs the store associates with ``thread_id``."""
        return await self._store.list_by_thread(thread_id)

    async def cancel_run(self, run_id: str) -> TaskRecord | None:
        """Request cancellation of a run.

        Returns ``None`` when the store does not know the run. If the store
        already reports the run as ``cancelled`` (presumably a run that had not
        started yet — confirm against ``TaskStore.set_cancel_requested``), a
        ``cancelled`` event is published here. Otherwise the run's worker task,
        if still alive, is cancelled and the worker records and publishes the
        cancellation itself.
        """
        rec = await self._store.set_cancel_requested(run_id)
        if rec is None:
            return None
        if rec.status == TaskStatus.cancelled:
            await self._store.publish(run_id, self._event(run_id, "cancelled", {"status": rec.status.value}))
            return rec
        worker = self._task_by_run_id.get(run_id)
        if worker is not None and not worker.done():
            worker.cancel()
        return rec

    async def delete_run(self, run_id: str) -> tuple[bool, bool]:
        """Delete a run if it is finished; returns the store's ``delete_if_finished`` flags unchanged."""
        return await self._store.delete_if_finished(run_id)

    async def wait_run(self, run_id: str, timeout: float | None = None) -> TaskRecord | None:
        """Wait until a run leaves ``queued``/``running``, then return its latest record.

        Args:
            run_id: Run to wait for.
            timeout: Maximum seconds to wait; ``None`` waits indefinitely.

        Returns:
            ``None`` for an unknown run; otherwise the freshest stored record,
            which may still be in flight if ``timeout`` elapsed first.
        """
        rec = await self._store.get(run_id)
        if rec is None:
            return None
        if rec.status not in (TaskStatus.queued, TaskStatus.running):
            return rec
        try:
            # wait_for with timeout=None blocks until the event is set.
            await asyncio.wait_for(rec.done_event.wait(), timeout=timeout)
        except asyncio.TimeoutError:
            # Deadline reached: fall through and report the current state.
            pass
        return await self._store.get(run_id)

    async def subscribe_run_stream(self, run_id: str):
        """Subscribe to a run's events; returns the store's handle for ``unsubscribe_run_stream``."""
        return await self._store.subscribe(run_id)

    async def unsubscribe_run_stream(self, run_id: str, queue) -> None:
        """Release a subscription previously returned by ``subscribe_run_stream``."""
        await self._store.unsubscribe(run_id, queue)

    async def close(self) -> None:
        """Cancel every in-flight worker and wait for all of them to finish unwinding."""
        workers = list(self._background_tasks)
        for worker in workers:
            worker.cancel()
        if workers:
            await asyncio.gather(*workers, return_exceptions=True)
        self._background_tasks.clear()
        self._task_by_run_id.clear()

    async def _worker(self, task_id: str) -> None:
        """Drive one run from ``queued`` to a terminal state.

        Once the run is marked running, any cancellation is recorded and published
        before ``CancelledError`` is re-raised. This includes cancellation while
        the worker is still waiting for a concurrency slot.
        """
        rec = await self._store.set_running(task_id)
        if rec is None or rec.status == TaskStatus.cancelled:
            # No record, or the run was already cancelled before the worker started.
            return
        try:
            await self._store.publish(task_id, self._event(task_id, "started", {"status": TaskStatus.running.value}))
            async with self._semaphore:
                await self._execute(task_id, rec)
        except asyncio.CancelledError:
            await self._mark_cancelled(task_id)
            raise

    async def _execute(self, task_id: str, rec: TaskRecord) -> None:
        """Run one task while holding a concurrency slot and record its outcome.

        Timeouts, ``BrowserRpcError`` and any other ``Exception`` are recorded as a
        failed run. ``CancelledError`` is intentionally left to ``_worker``. The
        browser runtime is torn down on every exit path, before the slot is
        released.
        """
        try:
            if rec.cancel_requested:
                await self._mark_cancelled(task_id)
                return
            runtime = await asyncio.to_thread(
                ensure_browser_runtime,
                task_id=task_id,
                metadata=rec.metadata,
                thread_id=rec.thread_id,
            )
            rpc_timeout = float(rec.timeout)
            if self._rpc_timeout_cap is not None:
                rpc_timeout = min(rpc_timeout, self._rpc_timeout_cap)
            raw = await asyncio.wait_for(
                self._rpc_client.run(task=rec.task, timeout_sec=rpc_timeout, rpc_url=runtime.get("rpc_url")),
                # Outer hard deadline: 5s beyond the run's own timeout.
                timeout=float(rec.timeout) + 5,
            )
            await self._record_result(task_id, self._with_runtime_metadata(raw, runtime))
        except asyncio.TimeoutError:
            await self._record_failure(task_id, "Timeout exceeded")
        except BrowserRpcError as exc:
            await self._record_failure(task_id, str(exc))
        except Exception as exc:
            await self._record_failure(task_id, f"Internal error: {exc}")
        finally:
            await self._cleanup_runtime(task_id, rec)

    async def _record_result(self, task_id: str, raw: dict[str, Any]) -> None:
        """Persist a completed RPC response, then publish its history and a completed/failed event."""
        success = bool(raw.get("success"))
        await self._store.set_done(
            task_id=task_id,
            success=success,
            raw_response=raw,
            error=None,
            result=raw.get("result") if isinstance(raw, dict) else None,
            history=self._extract_history(raw),
        )
        done = await self._store.get(task_id)
        if done is None:
            return
        await self._publish_history_events(done)
        await self._store.publish(
            task_id,
            self._event(task_id, "completed" if success else "failed", {
                "status": done.status.value,
                "output": done.result,
                "error": done.error,
            }),
        )

    async def _record_failure(self, task_id: str, error: str) -> None:
        """Mark a run as failed with ``error`` and publish a ``failed`` event."""
        await self._store.set_done(
            task_id=task_id,
            success=False,
            raw_response=None,
            error=error,
            history=None,
        )
        failed = await self._store.get(task_id)
        if failed is not None:
            await self._store.publish(task_id, self._event(task_id, "failed", {
                "status": failed.status.value,
                "error": failed.error,
            }))

    async def _mark_cancelled(self, task_id: str) -> None:
        """Record a run as cancelled and publish a ``cancelled`` event."""
        await self._store.set_cancelled(task_id)
        await self._store.publish(task_id, self._event(task_id, "cancelled", {"status": TaskStatus.cancelled.value}))

    async def _cleanup_runtime(self, task_id: str, rec: TaskRecord) -> None:
        """Best-effort browser-runtime teardown; failures are logged, never raised."""
        try:
            await asyncio.to_thread(
                cleanup_browser_runtime,
                task_id=task_id,
                metadata=rec.metadata,
                thread_id=rec.thread_id,
            )
        except Exception:
            self._log.exception("Failed to clean up browser runtime for run %s", task_id)

    async def _publish_history_events(self, rec: TaskRecord) -> None:
        """Publish one ``output`` event per history entry.

        An entry without a ``step`` key gets its 1-based position as the step.
        """
        for position, item in enumerate(rec.history, start=1):
            data = item.get("data")
            await self._store.publish(
                rec.task_id,
                self._event(rec.task_id, "output", {
                    "step": item.get("step", position),
                    "kind": item.get("kind") or item.get("type") or "system",
                    "content": item.get("content"),
                    "data": data if isinstance(data, dict) else {},
                }),
            )

    @staticmethod
    def _event(run_id: str, event: str, data: dict[str, Any]) -> dict[str, Any]:
        """Build an event envelope stamped with the current wall-clock time (``time.time()``)."""
        return {
            "run_id": run_id,
            "event": event,
            "ts": time.time(),
            "data": data,
        }

    @staticmethod
    def _extract_history(raw: dict | None) -> list[dict]:
        """Return the dict entries of ``raw["history"]``; ``[]`` when ``raw`` or the list is missing or malformed."""
        if not isinstance(raw, dict):
            return []
        events = raw.get("history")
        if not isinstance(events, list):
            return []
        return [event for event in events if isinstance(event, dict)]

    @staticmethod
    def _with_runtime_metadata(raw: dict[str, Any], runtime: dict[str, str] | None) -> dict[str, Any]:
        """Return a copy of ``raw`` enriched with browser-runtime details.

        ``browser_view`` is added only if ``raw`` lacks one. ``isolation_mode``
        is always set, defaulting to ``"shared"``. ``owner_hash`` is added when
        present. Non-dict ``raw`` or an empty/``None`` ``runtime`` is returned
        unchanged.
        """
        if not isinstance(raw, dict) or not runtime:
            return raw
        enriched = dict(raw)
        browser_view = runtime.get("browser_view")
        if browser_view and not enriched.get("browser_view"):
            enriched["browser_view"] = browser_view
        enriched["isolation_mode"] = runtime.get("isolation_mode", "shared")
        owner_hash = runtime.get("owner_hash")
        if owner_hash:
            enriched["owner_hash"] = owner_hash
        return enriched