корректная обработка ChatBusyError для WS эндпоинта. Docstring

This commit is contained in:
Егор Кандрушин 2026-04-19 15:48:34 +03:00
parent 69ec28037a
commit 58494ddea6
5 changed files with 52 additions and 15 deletions

View file

@ -1,3 +1,3 @@
from src.agent.service import AgentService, AgentChat from src.agent.service import AgentService, AgentChat, ChatBusyError
__all__ = ["AgentService", "AgentChat"] __all__ = ["AgentService", "AgentChat", "ChatBusyError"]

View file

@ -1,5 +1,5 @@
from typing import AsyncIterator, AsyncContextManager, Self from typing import AsyncIterator, AsyncContextManager, Self
from abc import ABC, abstractmethod from abc import abstractmethod
from src.agent.base import create_agent from src.agent.base import create_agent
from lambda_agent_api.server import ( from lambda_agent_api.server import (
@ -8,11 +8,22 @@ from lambda_agent_api.server import (
) )
class ChatInUseError(Exception): class ChatBusyError(Exception):
"""
Чат занят в другом блоке ``with``
"""
pass pass
class AgentChat(AsyncContextManager[Self]): class AgentChat(AsyncContextManager[Self]):
"""
Объект для работы с конкретным чатом.
В то же время является своеобразным 'lock', который позволяет "захватывать" работу с чатом (Mutex).
Для контроля доступа используется ``with``.
Нельзя войти в блок ``with`` с определенным ``chat_id``, если он уже используется в другом блоке.
Перед вызовом любых методов (``astream`` и т. д.) необходимо войти в блок ``with``.
Объект получается из AgentService.chat().
"""
chat_id: int chat_id: int
@abstractmethod @abstractmethod
@ -21,7 +32,7 @@ class AgentChat(AsyncContextManager[Self]):
class AgentService: class AgentService:
_instance = None _instance = None # синглтон
def __new__(cls): def __new__(cls):
if cls._instance is None: if cls._instance is None:
@ -30,7 +41,11 @@ class AgentService:
return cls._instance return cls._instance
class __AgentChat(AgentChat): class __AgentChat(AgentChat):
__locks: set[int] = set() """
Своеобразная реализация Mutex'а. Служит прослойкой до методов AgentService, но подставляет в них 'захваченный' chat_id.
"""
__locks: set[int] = set() # чаты, которые уже "взяты"
def __init__(self, service: AgentService, chat_id: int) -> None: def __init__(self, service: AgentService, chat_id: int) -> None:
self.__chat_id = chat_id self.__chat_id = chat_id
@ -43,7 +58,7 @@ class AgentService:
async def __aenter__(self): async def __aenter__(self):
if self.__chat_id in self.__locks: if self.__chat_id in self.__locks:
raise ChatInUseError() raise ChatBusyError()
self.__locks.add(self.__chat_id) self.__locks.add(self.__chat_id)
return self return self
@ -51,13 +66,16 @@ class AgentService:
async def __aexit__(self, exc_type, exc_val, exc_tb): async def __aexit__(self, exc_type, exc_val, exc_tb):
self.__locks.remove(self.__chat_id) self.__locks.remove(self.__chat_id)
async def astream(self, text: str) -> AsyncIterator[AgentEventUnion]: def astream(self, text: str) -> AsyncIterator[AgentEventUnion]:
if not self.__chat_id in self.__locks: if not self.__chat_id in self.__locks:
raise RuntimeError("Chat must be used in `with` statement") raise RuntimeError("Chat must be used in `with` statement")
return self.__service._AgentService__astream(self.__chat_id, text) return self.__service._AgentService__astream(self.__chat_id, text)
def chat(self, chat_id: int) -> AgentChat: def chat(self, chat_id: int) -> AgentChat:
"""
Возвращает объект чата с заданным ID. Не проверяет Mutex.
"""
return self.__AgentChat(self, chat_id) return self.__AgentChat(self, chat_id)
async def __astream(self, chat_id: int, text: str) -> AsyncIterator[AgentEventUnion]: async def __astream(self, chat_id: int, text: str) -> AsyncIterator[AgentEventUnion]:

View file

@ -1,7 +1,7 @@
from typing import Annotated, AsyncGenerator from typing import Annotated, AsyncGenerator
from fastapi import Depends from fastapi import Depends, WebSocketException, status
from src.agent import AgentService, AgentChat from src.agent import AgentService, AgentChat, ChatBusyError
def get_agent_service() -> AgentService: def get_agent_service() -> AgentService:
@ -12,3 +12,21 @@ async def get_chat(service: Annotated[AgentService, Depends(get_agent_service)],
chat_id: int) -> AsyncGenerator[AgentChat]: chat_id: int) -> AsyncGenerator[AgentChat]:
async with service.chat(chat_id) as chat: async with service.chat(chat_id) as chat:
yield chat yield chat
async def get_chat_ws(service: Annotated[AgentService, Depends(get_agent_service)],
chat_id: int) -> AsyncGenerator[AgentChat]:
"""
Версия ``get_chat`` для использования в WS эндпоинтах.
Ловит некоторые исключения (ChatBusyError) и оборачивает их в корректную WS ошибку.
Необходимо, т. к. глобальный exception handler в FastAPI предназначен для HTTP.\n
- ``ChatBusyError`` -> ``WebSocketException(status.WS_1008_POLICY_VIOLATION, reason=str(e))``
"""
try:
gen = get_chat(service, chat_id)
yield await gen.__anext__()
except StopAsyncIteration:
pass
except ChatBusyError as e:
raise WebSocketException(status.WS_1008_POLICY_VIOLATION,
reason=str(e))

View file

@ -11,16 +11,17 @@ from lambda_agent_api.server import (
from lambda_agent_api.client import ClientMessage, MsgUserMessage from lambda_agent_api.client import ClientMessage, MsgUserMessage
from src.agent import AgentChat from src.agent import AgentChat
from src.api.dependencies import get_chat from src.api.dependencies import get_chat_ws
router = APIRouter() router = APIRouter()
@router.websocket("/agent_ws/{chat_id}/") @router.websocket("/v1/agent_ws/{chat_id}/")
async def websocket_endpoint( async def websocket_endpoint(
ws: WebSocket, ws: WebSocket,
chat: Annotated[AgentChat, Depends(get_chat)], # важно использовать именно _ws вариант, чтобы корректно обрабатывались исключения
chat: Annotated[AgentChat, Depends(get_chat_ws)],
): ):
await ws.accept() await ws.accept()
await ws.send_text(MsgStatus().model_dump_json()) await ws.send_text(MsgStatus().model_dump_json())

View file

@ -2,13 +2,13 @@ from contextlib import asynccontextmanager
from fastapi import FastAPI from fastapi import FastAPI
from src.api.dependencies import get_agent_service
from src.api.external import router as ws_router from src.api.external import router as ws_router
from src.agent import AgentService
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
get_agent_service() AgentService() # инициализируем синглтон
yield yield