корректная обработка ChatBusyError для WS эндпоинта. Docstring
This commit is contained in:
parent
69ec28037a
commit
58494ddea6
5 changed files with 52 additions and 15 deletions
|
|
@ -1,3 +1,3 @@
|
|||
from src.agent.service import AgentService, AgentChat
|
||||
from src.agent.service import AgentService, AgentChat, ChatBusyError
|
||||
|
||||
__all__ = ["AgentService", "AgentChat"]
|
||||
__all__ = ["AgentService", "AgentChat", "ChatBusyError"]
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from typing import AsyncIterator, AsyncContextManager, Self
|
||||
from abc import ABC, abstractmethod
|
||||
from abc import abstractmethod
|
||||
|
||||
from src.agent.base import create_agent
|
||||
from lambda_agent_api.server import (
|
||||
|
|
@ -8,11 +8,22 @@ from lambda_agent_api.server import (
|
|||
)
|
||||
|
||||
|
||||
class ChatInUseError(Exception):
|
||||
class ChatBusyError(Exception):
|
||||
"""
|
||||
Чат занят в другом блоке ``with``
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class AgentChat(AsyncContextManager[Self]):
|
||||
"""
|
||||
Объект для работы с конкретным чатом.
|
||||
В то же время является своеобразным 'lock', который позволяет "захватывать" работу с чатом (Mutex).
|
||||
Для контроля доступа используется ``with``.
|
||||
Нельзя войти в блок ``with`` с определенным ``chat_id``, если он уже используется в другом блоке.
|
||||
Перед вызовом любых методов (``astream`` и т. д.) необходимо войти в блок ``with``.
|
||||
Объект получается из AgentService.chat().
|
||||
"""
|
||||
chat_id: int
|
||||
|
||||
@abstractmethod
|
||||
|
|
@ -21,7 +32,7 @@ class AgentChat(AsyncContextManager[Self]):
|
|||
|
||||
|
||||
class AgentService:
|
||||
_instance = None
|
||||
_instance = None # синглтон
|
||||
|
||||
def __new__(cls):
|
||||
if cls._instance is None:
|
||||
|
|
@ -30,7 +41,11 @@ class AgentService:
|
|||
return cls._instance
|
||||
|
||||
class __AgentChat(AgentChat):
|
||||
__locks: set[int] = set()
|
||||
"""
|
||||
Своеобразная реализация Mutex'а. Служит прослойкой до методов AgentService, но подставляет в них 'захваченный' chat_id.
|
||||
"""
|
||||
|
||||
__locks: set[int] = set() # чаты, которые уже "взяты"
|
||||
|
||||
def __init__(self, service: AgentService, chat_id: int) -> None:
|
||||
self.__chat_id = chat_id
|
||||
|
|
@ -43,7 +58,7 @@ class AgentService:
|
|||
|
||||
async def __aenter__(self):
|
||||
if self.__chat_id in self.__locks:
|
||||
raise ChatInUseError()
|
||||
raise ChatBusyError()
|
||||
|
||||
self.__locks.add(self.__chat_id)
|
||||
return self
|
||||
|
|
@ -51,13 +66,16 @@ class AgentService:
|
|||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
self.__locks.remove(self.__chat_id)
|
||||
|
||||
async def astream(self, text: str) -> AsyncIterator[AgentEventUnion]:
|
||||
def astream(self, text: str) -> AsyncIterator[AgentEventUnion]:
|
||||
if not self.__chat_id in self.__locks:
|
||||
raise RuntimeError("Chat must be used in `with` statement")
|
||||
|
||||
return self.__service._AgentService__astream(self.__chat_id, text)
|
||||
|
||||
def chat(self, chat_id: int) -> AgentChat:
|
||||
"""
|
||||
Возвращает объект чата с заданным ID. Не проверяет Mutex.
|
||||
"""
|
||||
return self.__AgentChat(self, chat_id)
|
||||
|
||||
async def __astream(self, chat_id: int, text: str) -> AsyncIterator[AgentEventUnion]:
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from typing import Annotated, AsyncGenerator
|
||||
from fastapi import Depends
|
||||
from fastapi import Depends, WebSocketException, status
|
||||
|
||||
from src.agent import AgentService, AgentChat
|
||||
from src.agent import AgentService, AgentChat, ChatBusyError
|
||||
|
||||
|
||||
def get_agent_service() -> AgentService:
|
||||
|
|
@ -12,3 +12,21 @@ async def get_chat(service: Annotated[AgentService, Depends(get_agent_service)],
|
|||
chat_id: int) -> AsyncGenerator[AgentChat]:
|
||||
async with service.chat(chat_id) as chat:
|
||||
yield chat
|
||||
|
||||
|
||||
async def get_chat_ws(service: Annotated[AgentService, Depends(get_agent_service)],
|
||||
chat_id: int) -> AsyncGenerator[AgentChat]:
|
||||
"""
|
||||
Версия ``get_chat`` для использования в WS эндпоинтах.
|
||||
Ловит некоторые исключения (ChatBusyError) и оборачивает их в корректную WS ошибку.
|
||||
Необходимо, т. к. глобальный exception handler в FastAPI предназначен для HTTP.\n
|
||||
- ``ChatBusyError`` -> ``WebSocketException(status.WS_1008_POLICY_VIOLATION, reason=str(e))``
|
||||
"""
|
||||
try:
|
||||
gen = get_chat(service, chat_id)
|
||||
yield await gen.__anext__()
|
||||
except StopAsyncIteration:
|
||||
pass
|
||||
except ChatBusyError as e:
|
||||
raise WebSocketException(status.WS_1008_POLICY_VIOLATION,
|
||||
reason=str(e))
|
||||
|
|
|
|||
|
|
@ -11,16 +11,17 @@ from lambda_agent_api.server import (
|
|||
from lambda_agent_api.client import ClientMessage, MsgUserMessage
|
||||
|
||||
from src.agent import AgentChat
|
||||
from src.api.dependencies import get_chat
|
||||
from src.api.dependencies import get_chat_ws
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.websocket("/agent_ws/{chat_id}/")
|
||||
@router.websocket("/v1/agent_ws/{chat_id}/")
|
||||
async def websocket_endpoint(
|
||||
ws: WebSocket,
|
||||
chat: Annotated[AgentChat, Depends(get_chat)],
|
||||
# важно использовать именно _ws вариант, чтобы корректно обрабатывались исключения
|
||||
chat: Annotated[AgentChat, Depends(get_chat_ws)],
|
||||
):
|
||||
await ws.accept()
|
||||
await ws.send_text(MsgStatus().model_dump_json())
|
||||
|
|
|
|||
|
|
@ -2,13 +2,13 @@ from contextlib import asynccontextmanager
|
|||
|
||||
from fastapi import FastAPI
|
||||
|
||||
from src.api.dependencies import get_agent_service
|
||||
from src.api.external import router as ws_router
|
||||
from src.agent import AgentService
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
get_agent_service()
|
||||
AgentService() # инициализируем синглтон
|
||||
yield
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue