From 58494ddea6734b30822ba2be9df3432deacb8ef2 Mon Sep 17 00:00:00 2001 From: MrKan Date: Sun, 19 Apr 2026 15:48:34 +0300 Subject: [PATCH] =?UTF-8?q?=D0=BA=D0=BE=D1=80=D1=80=D0=B5=D0=BA=D1=82?= =?UTF-8?q?=D0=BD=D0=B0=D1=8F=20=D0=BE=D0=B1=D1=80=D0=B0=D0=B1=D0=BE=D1=82?= =?UTF-8?q?=D0=BA=D0=B0=20ChatBusyError=20=D0=B4=D0=BB=D1=8F=20WS=20=D1=8D?= =?UTF-8?q?=D0=BD=D0=B4=D0=BF=D0=BE=D0=B8=D0=BD=D1=82=D0=B0.=20Docstring?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/agent/__init__.py | 4 ++-- src/agent/service.py | 30 ++++++++++++++++++++++++------ src/api/dependencies.py | 22 ++++++++++++++++++++-- src/api/external.py | 7 ++++--- src/main.py | 4 ++-- 5 files changed, 52 insertions(+), 15 deletions(-) diff --git a/src/agent/__init__.py b/src/agent/__init__.py index ccfc9bd..f184284 100644 --- a/src/agent/__init__.py +++ b/src/agent/__init__.py @@ -1,3 +1,3 @@ -from src.agent.service import AgentService, AgentChat +from src.agent.service import AgentService, AgentChat, ChatBusyError -__all__ = ["AgentService", "AgentChat"] +__all__ = ["AgentService", "AgentChat", "ChatBusyError"] diff --git a/src/agent/service.py b/src/agent/service.py index aa9b45f..170bb53 100644 --- a/src/agent/service.py +++ b/src/agent/service.py @@ -1,5 +1,5 @@ from typing import AsyncIterator, AsyncContextManager, Self -from abc import ABC, abstractmethod +from abc import abstractmethod from src.agent.base import create_agent from lambda_agent_api.server import ( @@ -8,11 +8,22 @@ from lambda_agent_api.server import ( ) -class ChatInUseError(Exception): +class ChatBusyError(Exception): + """ + Чат занят в другом блоке ``with`` + """ pass class AgentChat(AsyncContextManager[Self]): + """ + Объект для работы с конкретным чатом. + В то же время является своеобразным 'lock', который позволяет "захватывать" работу с чатом (Mutex). + Для контроля доступа используется ``with``. + Нельзя войти в блок ``with`` с определенным ``chat_id``, если он уже используется в другом блоке. + Перед вызовом любых методов (``astream`` и т. д.) необходимо войти в блок ``with``. + Объект получается из AgentService.chat(). + """ chat_id: int @abstractmethod @@ -21,7 +32,7 @@ class AgentChat(AsyncContextManager[Self]): class AgentService: - _instance = None + _instance = None # синглтон def __new__(cls): if cls._instance is None: @@ -30,7 +41,11 @@ class AgentService: return cls._instance class __AgentChat(AgentChat): - __locks: set[int] = set() + """ + Своеобразная реализация Mutex'а. Служит прослойкой до методов AgentService, но подставляет в них 'захваченный' chat_id. + """ + + __locks: set[int] = set() # чаты, которые уже "взяты" def __init__(self, service: AgentService, chat_id: int) -> None: self.__chat_id = chat_id @@ -43,7 +58,7 @@ class AgentService: async def __aenter__(self): if self.__chat_id in self.__locks: - raise ChatInUseError() + raise ChatBusyError() self.__locks.add(self.__chat_id) return self @@ -51,13 +66,16 @@ class AgentService: async def __aexit__(self, exc_type, exc_val, exc_tb): self.__locks.remove(self.__chat_id) - async def astream(self, text: str) -> AsyncIterator[AgentEventUnion]: + def astream(self, text: str) -> AsyncIterator[AgentEventUnion]: if not self.__chat_id in self.__locks: raise RuntimeError("Chat must be used in `with` statement") return self.__service._AgentService__astream(self.__chat_id, text) def chat(self, chat_id: int) -> AgentChat: + """ + Возвращает объект чата с заданным ID. Не проверяет Mutex. + """ return self.__AgentChat(self, chat_id) async def __astream(self, chat_id: int, text: str) -> AsyncIterator[AgentEventUnion]: diff --git a/src/api/dependencies.py b/src/api/dependencies.py index 01f2cf1..142b399 100644 --- a/src/api/dependencies.py +++ b/src/api/dependencies.py @@ -1,7 +1,7 @@ from typing import Annotated, AsyncGenerator -from fastapi import Depends +from fastapi import Depends, WebSocketException, status -from src.agent import AgentService, AgentChat +from src.agent import AgentService, AgentChat, ChatBusyError def get_agent_service() -> AgentService: @@ -12,3 +12,21 @@ async def get_chat(service: Annotated[AgentService, Depends(get_agent_service)], chat_id: int) -> AsyncGenerator[AgentChat]: async with service.chat(chat_id) as chat: yield chat + + +async def get_chat_ws(service: Annotated[AgentService, Depends(get_agent_service)], + chat_id: int) -> AsyncGenerator[AgentChat]: + """ + Версия ``get_chat`` для использования в WS эндпоинтах. + Ловит некоторые исключения (ChatBusyError) и оборачивает их в корректную WS ошибку. + Необходимо, т. к. глобальный exception handler в FastAPI предназначен для HTTP.\n + - ``ChatBusyError`` -> ``WebSocketException(status.WS_1008_POLICY_VIOLATION, reason=str(e))`` + """ + try: + gen = get_chat(service, chat_id) + yield await gen.__anext__() + except StopAsyncIteration: + pass + except ChatBusyError as e: + raise WebSocketException(status.WS_1008_POLICY_VIOLATION, + reason=str(e)) diff --git a/src/api/external.py b/src/api/external.py index 4985dee..d0d9445 100644 --- a/src/api/external.py +++ b/src/api/external.py @@ -11,16 +11,17 @@ from lambda_agent_api.server import ( from lambda_agent_api.client import ClientMessage, MsgUserMessage from src.agent import AgentChat -from src.api.dependencies import get_chat +from src.api.dependencies import get_chat_ws router = APIRouter() -@router.websocket("/agent_ws/{chat_id}/") +@router.websocket("/v1/agent_ws/{chat_id}/") async def websocket_endpoint( ws: WebSocket, - chat: Annotated[AgentChat, Depends(get_chat)], + # важно использовать именно _ws вариант, чтобы корректно обрабатывались исключения + chat: Annotated[AgentChat, Depends(get_chat_ws)], ): await ws.accept() await ws.send_text(MsgStatus().model_dump_json()) diff --git a/src/main.py b/src/main.py index 02418a8..9a5e83c 100644 --- a/src/main.py +++ b/src/main.py @@ -2,13 +2,13 @@ from contextlib import asynccontextmanager from fastapi import FastAPI -from src.api.dependencies import get_agent_service from src.api.external import router as ws_router +from src.agent import AgentService @asynccontextmanager async def lifespan(app: FastAPI): - get_agent_service() + AgentService() # инициализируем синглтон yield