"""In-memory task registry with per-task subscriber queues.

All records live in process memory. Every method that reads or mutates the
shared dictionaries does so while holding a single ``asyncio.Lock``.
"""

import time
import uuid
from asyncio import Event, Lock, Queue
from dataclasses import dataclass, field
from typing import Any

from api.domain.task_status import TaskStatus


@dataclass
class TaskRecord:
    """Mutable state for one submitted task."""

    task_id: str
    thread_id: str
    task: str
    timeout: int
    metadata: dict[str, Any] | None
    status: TaskStatus = TaskStatus.queued
    # Creation timestamp from time.time(). The field name is part of the
    # record's public shape, so it is kept as spelled.
    create_at: float = field(default_factory=time.time)
    started_at: float | None = None
    finished_at: float | None = None
    result: str | None = None
    error: str | None = None
    raw_response: dict[str, Any] | None = None
    history: list[dict[str, Any]] = field(default_factory=list)
    # Set by TaskStore.set_cancel_requested. For running tasks the worker is
    # presumably expected to poll this flag -- confirm against the worker.
    cancel_requested: bool = False
    # Set by the store whenever the record enters a terminal state.
    done_event: Event = field(default_factory=Event)

    @property
    def execution_time(self) -> float:
        """Return the seconds spent running, or 0.0 if the task never started.

        While the task has not finished (``finished_at`` is None), the time
        elapsed up to now is returned. Negative spans (e.g. caused by a
        wall-clock adjustment) are clamped to 0.0.
        """
        if self.started_at is None:
            return 0.0
        end = time.time() if self.finished_at is None else self.finished_at
        return max(0.0, end - self.started_at)


class TaskStore:
    """In-memory store of TaskRecord objects plus per-task event queues.

    Shared state is only touched while holding an ``asyncio.Lock``. This
    serializes coroutines on the event loop but does not make the store safe
    to use from other OS threads.
    """

    def __init__(self) -> None:
        self._lock = Lock()
        # task_id -> record
        self._tasks: dict[str, TaskRecord] = {}
        # thread_id -> task ids in creation order
        self._thread_index: dict[str, list[str]] = {}
        # task_id -> queues that receive events published for that task
        self._subscribers: dict[str, set[Queue[dict[str, Any]]]] = {}

    async def create(
        self,
        task: str,
        timeout: int,
        metadata: dict[str, Any] | None,
        thread_id: str = "default",
    ) -> TaskRecord:
        """Register a new queued task and return its record.

        The record gets a fresh random hex id. It is appended to the
        ``thread_id`` index and starts with an empty subscriber set.
        """
        task_id = uuid.uuid4().hex
        rec = TaskRecord(
            task_id=task_id,
            thread_id=thread_id,
            task=task,
            timeout=timeout,
            metadata=metadata,
        )
        async with self._lock:
            self._tasks[task_id] = rec
            self._thread_index.setdefault(thread_id, []).append(task_id)
            self._subscribers.setdefault(task_id, set())
        return rec

    async def list_by_thread(self, thread_id: str) -> list[TaskRecord]:
        """Return the records of ``thread_id`` in creation order.

        Returns an empty list for unknown threads.
        """
        async with self._lock:
            ids = list(self._thread_index.get(thread_id, []))
            return [self._tasks[item] for item in ids if item in self._tasks]

    async def get(self, task_id: str) -> TaskRecord | None:
        """Return the record for ``task_id``, or None if it is unknown."""
        async with self._lock:
            return self._tasks.get(task_id)

    async def set_running(self, task_id: str) -> TaskRecord | None:
        """Mark a task as running and stamp ``started_at``.

        Returns None for unknown ids. Cancelled tasks are returned
        unchanged, so callers should check ``status`` before doing work.
        """
        async with self._lock:
            rec = self._tasks.get(task_id)
            if rec is None:
                return None
            if rec.status == TaskStatus.cancelled:
                return rec
            rec.status = TaskStatus.running
            rec.started_at = time.time()
            return rec

    async def set_done(
        self,
        task_id: str,
        success: bool,
        raw_response: dict[str, Any] | None,
        error: str | None,
        result: str | None = None,
        history: list[dict[str, Any]] | None = None,
    ) -> TaskRecord | None:
        """Record a task's outcome as succeeded or failed and signal completion.

        When ``error``/``result`` are not given explicitly, they fall back to
        the ``"error"``/``"result"`` keys of ``raw_response`` if it is a dict.
        ``history`` replaces the stored history; None stores an empty list.
        Returns None for unknown ids.
        """
        # NOTE(review): unlike set_cancelled, this does not skip records
        # already in a terminal state, so a late completion overwrites a
        # cancellation -- confirm whether that is intended.
        async with self._lock:
            rec = self._tasks.get(task_id)
            if rec is None:
                return None
            response = raw_response if isinstance(raw_response, dict) else None
            rec.finished_at = time.time()
            rec.raw_response = raw_response
            if error is not None:
                rec.error = error
            else:
                rec.error = response.get("error") if response is not None else None
            if result is not None:
                rec.result = result
            else:
                rec.result = response.get("result") if response is not None else None
            rec.history = list(history or [])
            rec.status = TaskStatus.succeeded if success else TaskStatus.failed
            rec.done_event.set()
            return rec

    async def set_cancel_requested(self, task_id: str) -> TaskRecord | None:
        """Flag a task for cancellation.

        A still-queued task is cancelled immediately and its ``done_event``
        is set. A task in any other state only has ``cancel_requested`` set.
        Returns None for unknown ids.
        """
        async with self._lock:
            rec = self._tasks.get(task_id)
            if rec is None:
                return None
            rec.cancel_requested = True
            if rec.status == TaskStatus.queued:
                rec.status = TaskStatus.cancelled
                rec.finished_at = time.time()
                rec.error = "Cancelled by user"
                rec.done_event.set()
            return rec

    async def set_cancelled(self, task_id: str, error: str = "Cancelled by user") -> TaskRecord | None:
        """Move a non-terminal task to cancelled and signal completion.

        Records already succeeded, failed or cancelled are returned
        unchanged. Returns None for unknown ids.
        """
        async with self._lock:
            rec = self._tasks.get(task_id)
            if rec is None:
                return None
            if rec.status in (TaskStatus.succeeded, TaskStatus.failed, TaskStatus.cancelled):
                return rec
            rec.status = TaskStatus.cancelled
            rec.finished_at = time.time()
            rec.error = error
            rec.done_event.set()
            return rec

    async def delete_if_finished(self, task_id: str) -> tuple[bool, bool]:
        """Delete a task unless it is still queued or running.

        Returns ``(found, deleted)``:

        - ``(False, False)`` for an unknown id.
        - ``(True, False)`` if the task is still active.
        - ``(True, True)`` once the record, its thread-index entry and its
          subscriber set have been removed.
        """
        async with self._lock:
            rec = self._tasks.get(task_id)
            if rec is None:
                return False, False
            if rec.status in (TaskStatus.queued, TaskStatus.running):
                return True, False
            del self._tasks[task_id]
            thread_ids = self._thread_index.get(rec.thread_id)
            if thread_ids is not None:
                if task_id in thread_ids:
                    thread_ids.remove(task_id)
                # Drop the bucket once empty so deleted threads do not leave
                # stale keys behind in the index forever.
                if not thread_ids:
                    del self._thread_index[rec.thread_id]
            self._subscribers.pop(task_id, None)
            return True, True

    async def subscribe(self, task_id: str) -> Queue[dict[str, Any]] | None:
        """Create and register an unbounded event queue for ``task_id``.

        Returns None if the task is unknown.
        """
        queue: Queue[dict[str, Any]] = Queue()
        async with self._lock:
            if task_id not in self._tasks:
                return None
            self._subscribers.setdefault(task_id, set()).add(queue)
        return queue

    async def unsubscribe(self, task_id: str, queue: Queue[dict[str, Any]]) -> None:
        """Stop delivering events for ``task_id`` to ``queue``.

        Unknown ids and queues are ignored.
        """
        async with self._lock:
            subscribers = self._subscribers.get(task_id)
            if subscribers is not None:
                subscribers.discard(queue)

    async def publish(self, task_id: str, event: dict[str, Any]) -> None:
        """Deliver ``event`` to every current subscriber of ``task_id``.

        Delivery is best-effort: a failure on one queue is ignored so the
        remaining subscribers still get the event. Unknown ids are a no-op.
        """
        async with self._lock:
            # Snapshot under the lock; the non-blocking puts need no lock.
            subscribers = list(self._subscribers.get(task_id, set()))
        for queue in subscribers:
            try:
                queue.put_nowait(event)
            except Exception:
                # Deliberately broad: besides QueueFull, Python 3.13+ can
                # raise QueueShutDown here. One bad queue must not stop
                # delivery to the others.
                continue