BrowserUse_and_ComputerUse_.../api/repositories/task_store.py

164 lines
5.9 KiB
Python

import time
import uuid
from asyncio import Event, Lock, Queue
from dataclasses import dataclass, field
from typing import Any
from api.domain.task_status import TaskStatus
@dataclass
class TaskRecord:
    """In-memory record of one submitted task and its lifecycle state.

    Timestamps are wall-clock seconds as returned by ``time.time()``.
    """

    task_id: str
    thread_id: str
    task: str
    timeout: int
    metadata: dict[str, Any] | None
    status: TaskStatus = TaskStatus.queued
    # Creation timestamp. NOTE(review): the name looks like a typo for
    # ``created_at``; kept unchanged because other code may reference it.
    create_at: float = field(default_factory=time.time)
    # Stamped when the task starts running / reaches a terminal state.
    started_at: float | None = None
    finished_at: float | None = None
    result: str | None = None
    error: str | None = None
    raw_response: dict[str, Any] | None = None
    history: list[dict[str, Any]] = field(default_factory=list)
    # True once cancellation was requested; the task may still be running.
    cancel_requested: bool = False
    # Set when the task reaches a terminal state so awaiting code can resume.
    done_event: Event = field(default_factory=Event)

    @property
    def execution_time(self) -> float:
        """Return the seconds spent running, or ``0.0`` if never started.

        While the task is still running (``finished_at`` unset) the time
        elapsed up to now is returned. Negative spans are clamped to ``0.0``
        (presumably to guard against wall-clock adjustments).
        """
        if self.started_at is None:
            return 0.0
        end = self.finished_at if self.finished_at is not None else time.time()
        # Use a float floor so the result type always matches the annotation.
        return max(0.0, end - self.started_at)
class TaskStore:
    """In-memory registry of task records with per-thread indexing and pub/sub.

    The store's own dictionaries are only read and written while holding a
    single ``asyncio.Lock``, so coroutines on the same event loop see a
    consistent view. Records returned by these methods are the live objects
    held by the store, not copies.
    """

    def __init__(self) -> None:
        self._lock = Lock()
        # task_id -> record
        self._tasks: dict[str, TaskRecord] = {}
        # thread_id -> task ids, in creation order
        self._thread_index: dict[str, list[str]] = {}
        # task_id -> queues that receive events passed to publish()
        self._subscribers: dict[str, set[Queue[dict[str, Any]]]] = {}

    async def create(
        self,
        task: str,
        timeout: int,
        metadata: dict[str, Any] | None,
        thread_id: str = "default",
    ) -> TaskRecord:
        """Create a new queued task under ``thread_id`` and return its record.

        The task id is a random UUID4 in hex form.
        """
        task_id = uuid.uuid4().hex
        rec = TaskRecord(
            task_id=task_id,
            thread_id=thread_id,
            task=task,
            timeout=timeout,
            metadata=metadata,
        )
        async with self._lock:
            self._tasks[task_id] = rec
            self._thread_index.setdefault(thread_id, []).append(task_id)
            self._subscribers.setdefault(task_id, set())
        return rec

    async def list_by_thread(self, thread_id: str) -> list[TaskRecord]:
        """Return the records of ``thread_id`` in creation order ([] if unknown)."""
        async with self._lock:
            ids = list(self._thread_index.get(thread_id, []))
            return [self._tasks[item] for item in ids if item in self._tasks]

    async def get(self, task_id: str) -> TaskRecord | None:
        """Return the record for ``task_id``, or ``None`` if it does not exist."""
        async with self._lock:
            return self._tasks.get(task_id)

    async def set_running(self, task_id: str) -> TaskRecord | None:
        """Mark a task as running and stamp ``started_at``.

        Returns ``None`` for an unknown task. A task that is already
        cancelled is returned unchanged (presumably so the caller can notice
        and skip execution).
        """
        async with self._lock:
            rec = self._tasks.get(task_id)
            if rec is None:
                return None
            if rec.status == TaskStatus.cancelled:
                return rec
            rec.status = TaskStatus.running
            rec.started_at = time.time()
            return rec

    async def set_done(
        self,
        task_id: str,
        success: bool,
        raw_response: dict[str, Any] | None,
        error: str | None,
        result: str | None = None,
        history: list[dict[str, Any]] | None = None,
    ) -> TaskRecord | None:
        """Record the outcome of a finished task and set its ``done_event``.

        Args:
            task_id: Task to update.
            success: Selects ``succeeded`` vs ``failed`` as the final status.
            raw_response: Stored verbatim; when it is a dict, its ``"error"``
                and ``"result"`` keys serve as fallbacks for ``error`` and
                ``result``.
            error: Explicit error message; takes precedence over the fallback.
            result: Explicit result; takes precedence over the fallback.
            history: Copied into the record; ``None`` stores an empty list.

        Returns:
            The record, or ``None`` if ``task_id`` is unknown. A task that is
            already cancelled is returned unchanged: cancellation is terminal
            and a late completion must not flip it to succeeded/failed.
        """
        async with self._lock:
            rec = self._tasks.get(task_id)
            if rec is None:
                return None
            if rec.status == TaskStatus.cancelled:
                return rec
            fallback = raw_response if isinstance(raw_response, dict) else {}
            rec.finished_at = time.time()
            rec.raw_response = raw_response
            rec.error = error if error is not None else fallback.get("error")
            rec.result = result if result is not None else fallback.get("result")
            rec.history = list(history or [])
            rec.status = TaskStatus.succeeded if success else TaskStatus.failed
            rec.done_event.set()
            return rec

    async def set_cancel_requested(self, task_id: str) -> TaskRecord | None:
        """Flag a task for cancellation.

        A task that is still queued is cancelled immediately: its status,
        ``finished_at`` and error are set and ``done_event`` fires. For any
        other status only ``cancel_requested`` is set; the code executing the
        task is presumably expected to observe it and call ``set_cancelled``.
        Returns ``None`` for an unknown task.
        """
        async with self._lock:
            rec = self._tasks.get(task_id)
            if rec is None:
                return None
            rec.cancel_requested = True
            if rec.status == TaskStatus.queued:
                rec.status = TaskStatus.cancelled
                rec.finished_at = time.time()
                rec.error = "Cancelled by user"
                rec.done_event.set()
            return rec

    async def set_cancelled(self, task_id: str, error: str = "Cancelled by user") -> TaskRecord | None:
        """Move a non-terminal task to ``cancelled`` and set its ``done_event``.

        Tasks already succeeded, failed or cancelled are returned unchanged.
        Returns ``None`` for an unknown task.
        """
        async with self._lock:
            rec = self._tasks.get(task_id)
            if rec is None:
                return None
            if rec.status in (TaskStatus.succeeded, TaskStatus.failed, TaskStatus.cancelled):
                return rec
            rec.status = TaskStatus.cancelled
            rec.finished_at = time.time()
            rec.error = error
            rec.done_event.set()
            return rec

    async def delete_if_finished(self, task_id: str) -> tuple[bool, bool]:
        """Delete a task unless it is still queued or running.

        Returns:
            ``(found, deleted)``: ``(False, False)`` for an unknown task,
            ``(True, False)`` if it is still queued/running, and
            ``(True, True)`` once it has been removed together with its
            thread-index entry and subscriber set.
        """
        async with self._lock:
            rec = self._tasks.get(task_id)
            if rec is None:
                return False, False
            if rec.status in (TaskStatus.queued, TaskStatus.running):
                return True, False
            del self._tasks[task_id]
            thread_list = self._thread_index.get(rec.thread_id)
            if thread_list is not None:
                if task_id in thread_list:
                    thread_list.remove(task_id)
                if not thread_list:
                    # Drop empty buckets so the index does not grow without bound.
                    del self._thread_index[rec.thread_id]
            self._subscribers.pop(task_id, None)
            return True, True

    async def subscribe(self, task_id: str) -> Queue[dict[str, Any]] | None:
        """Return a new unbounded queue fed by ``publish`` for ``task_id``.

        Returns ``None`` if the task does not exist.
        """
        queue: Queue[dict[str, Any]] = Queue()
        async with self._lock:
            if task_id not in self._tasks:
                return None
            self._subscribers.setdefault(task_id, set()).add(queue)
        return queue

    async def unsubscribe(self, task_id: str, queue: Queue[dict[str, Any]]) -> None:
        """Stop delivering events for ``task_id`` to ``queue`` (no-op if absent)."""
        async with self._lock:
            subscribers = self._subscribers.get(task_id)
            if subscribers is not None:
                subscribers.discard(queue)

    async def publish(self, task_id: str, event: dict[str, Any]) -> None:
        """Deliver ``event`` to every current subscriber of ``task_id``.

        The subscriber set is snapshotted under the lock and delivery happens
        outside it. The same ``event`` object (not a copy) is put on each
        queue. Delivery is best-effort: a queue that rejects the event is
        skipped.
        """
        async with self._lock:
            subscribers = list(self._subscribers.get(task_id, set()))
        for queue in subscribers:
            try:
                queue.put_nowait(event)
            except Exception:
                # Best-effort fan-out: one failing subscriber must not stop the rest.
                continue