diff --git a/Dockerfile b/Dockerfile index b9bc44e..e28cf48 100644 --- a/Dockerfile +++ b/Dockerfile @@ -4,7 +4,7 @@ ENV PYTHONDONTWRITEBYTECODE=1 \ PYTHONUNBUFFERED=1 WORKDIR /app -RUN apt update && apt install make -y +RUN apt update && apt install make sudo -y ENV AGENT_USER="agent" ENV WORKSPACE_DIR="/workspace/" @@ -52,8 +52,7 @@ ENV PATH="/app/.venv/bin:$PATH" COPY Makefile ./ COPY .mk/ ./.mk/ RUN chown root:root /app && chmod 700 /app -RUN apt install sudo -y && \ - echo "agent ALL=(ALL) NOPASSWD: /usr/bin/apt*" >> /etc/sudoers +RUN echo "agent ALL=(ALL) NOPASSWD: /usr/bin/apt*" >> /etc/sudoers EXPOSE 8000 diff --git a/docker-compose.yml b/docker-compose.yml index 27ba539..d639315 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -19,15 +19,11 @@ services: agent_api: ${AGENT_API_PATH} volumes: - ./src:/app/src - - ${AGENT_API_PATH}:/agent-api/ + - ${AGENT_API_PATH}:/agent_api/ - ./workspace:/workspace/ ports: - "8000:8000" env_file: - .env - cap_add: # для работы bwrap - - SYS_ADMIN - security_opt: # для работы bwrap - - seccomp:unconfined profiles: - dev diff --git a/src/agent/__init__.py b/src/agent/__init__.py index ff3ec8c..f184284 100644 --- a/src/agent/__init__.py +++ b/src/agent/__init__.py @@ -1,3 +1,3 @@ -from src.agent.service import AgentService, get_agent_service +from src.agent.service import AgentService, AgentChat, ChatBusyError -__all__ = ["AgentService", "get_agent_service"] +__all__ = ["AgentService", "AgentChat", "ChatBusyError"] diff --git a/src/agent/service.py b/src/agent/service.py index 05f50c4..170bb53 100644 --- a/src/agent/service.py +++ b/src/agent/service.py @@ -1,4 +1,5 @@ -from typing import AsyncIterator +from typing import AsyncIterator, AsyncContextManager, Self +from abc import abstractmethod from src.agent.base import create_agent from lambda_agent_api.server import ( @@ -6,18 +7,79 @@ from lambda_agent_api.server import ( MsgEventToolResult, MsgEventEnd ) + +class ChatBusyError(Exception): + """ + Чат занят в другом блоке ``with`` + """ + pass + + +class AgentChat(AsyncContextManager[Self]): + """ + Объект для работы с конкретным чатом. + В то же время является своеобразным 'lock', который позволяет "захватывать" работу с чатом (Mutex). + Для контроля доступа используется ``with``. + Нельзя войти в блок ``with`` с определенным ``chat_id``, если он уже используется в другом блоке. + Перед вызовом любых методов (``astream`` и т. д.) необходимо войти в блок ``with``. + Объект получается из AgentService.chat(). + """ + chat_id: int + + @abstractmethod + def astream(self, text: str) -> AsyncIterator[AgentEventUnion]: + ... + + class AgentService: - _instance = None + _instance = None # синглтон def __new__(cls): if cls._instance is None: cls._instance = super().__new__(cls) cls._instance._agent = create_agent() - cls._instance._thread_id = "default" return cls._instance - async def astream(self, text: str) -> AsyncIterator[AgentEventUnion]: - config = {"configurable": {"thread_id": self._thread_id}} + class __AgentChat(AgentChat): + """ + Своеобразная реализация Mutex'а. Служит прослойкой до методов AgentService, но подставляет в них 'захваченный' chat_id. + """ + + __locks: set[int] = set() # чаты, которые уже "взяты" + + def __init__(self, service: AgentService, chat_id: int) -> None: + self.__chat_id = chat_id + self.__service = service + + # noinspection PyProtocol + @property + def chat_id(self) -> int: + return self.__chat_id + + async def __aenter__(self): + if self.__chat_id in self.__locks: + raise ChatBusyError() + + self.__locks.add(self.__chat_id) + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + self.__locks.remove(self.__chat_id) + + def astream(self, text: str) -> AsyncIterator[AgentEventUnion]: + if not self.__chat_id in self.__locks: + raise RuntimeError("Chat must be used in `with` statement") + + return self.__service._AgentService__astream(self.__chat_id, text) + + def chat(self, chat_id: int) -> AgentChat: + """ + Возвращает объект чата с заданным ID. Не проверяет Mutex. + """ + return self.__AgentChat(self, chat_id) + + async def __astream(self, chat_id: int, text: str) -> AsyncIterator[AgentEventUnion]: + config = {"configurable": {"thread_id": chat_id}} # Используем astream_events для перехвата детальных событий (инструменты, чанки и т.д.) async for event in self._agent.astream_events( @@ -52,9 +114,3 @@ class AgentService: # 3. В конце генерации отправляем событие завершения yield MsgEventEnd(tokens_used=0) # потом заменить на метадату - - - - -def get_agent_service() -> AgentService: - return AgentService() diff --git a/src/api/dependencies.py b/src/api/dependencies.py new file mode 100644 index 0000000..142b399 --- /dev/null +++ b/src/api/dependencies.py @@ -0,0 +1,32 @@ +from typing import Annotated, AsyncGenerator +from fastapi import Depends, WebSocketException, status + +from src.agent import AgentService, AgentChat, ChatBusyError + + +def get_agent_service() -> AgentService: + return AgentService() + + +async def get_chat(service: Annotated[AgentService, Depends(get_agent_service)], + chat_id: int) -> AsyncGenerator[AgentChat]: + async with service.chat(chat_id) as chat: + yield chat + + +async def get_chat_ws(service: Annotated[AgentService, Depends(get_agent_service)], + chat_id: int) -> AsyncGenerator[AgentChat]: + """ + Версия ``get_chat`` для использования в WS эндпоинтах. + Ловит некоторые исключения (ChatBusyError) и оборачивает их в корректную WS ошибку. + Необходимо, т. к. глобальный exception handler в FastAPI предназначен для HTTP.\n + - ``ChatBusyError`` -> ``WebSocketException(status.WS_1008_POLICY_VIOLATION, reason=str(e))`` + """ + try: + gen = get_chat(service, chat_id) + yield await gen.__anext__() + except StopAsyncIteration: + pass + except ChatBusyError as e: + raise WebSocketException(status.WS_1008_POLICY_VIOLATION, + reason=str(e)) diff --git a/src/api/external.py b/src/api/external.py index c93f2cb..d0d9445 100644 --- a/src/api/external.py +++ b/src/api/external.py @@ -1,3 +1,5 @@ +from typing import Annotated + from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Depends from lambda_agent_api.server import ( @@ -8,16 +10,18 @@ from lambda_agent_api.server import ( ) from lambda_agent_api.client import ClientMessage, MsgUserMessage -from src.agent import get_agent_service, AgentService +from src.agent import AgentChat +from src.api.dependencies import get_chat_ws router = APIRouter() -@router.websocket("/agent_ws/") +@router.websocket("/v1/agent_ws/{chat_id}/") async def websocket_endpoint( ws: WebSocket, - agent_service: AgentService = Depends(get_agent_service), + # важно использовать именно _ws вариант, чтобы корректно обрабатывались исключения + chat: Annotated[AgentChat, Depends(get_chat_ws)], ): await ws.accept() await ws.send_text(MsgStatus().model_dump_json()) @@ -26,7 +30,7 @@ async def websocket_endpoint( while True: raw = await ws.receive_text() msg = ClientMessage.validate_json(raw) - await process_message(ws, msg, agent_service) + await process_message(ws, chat, msg) except WebSocketDisconnect: pass @@ -36,9 +40,9 @@ async def websocket_endpoint( ) -async def process_message(ws: WebSocket, msg, agent_service: AgentService): +async def process_message(ws: WebSocket, chat: AgentChat, msg): match msg: case MsgUserMessage(): - async for chunk in agent_service.astream(msg.text): + async for chunk in chat.astream(msg.text): await ws.send_text(chunk.model_dump_json()) await ws.send_text(MsgEventEnd(tokens_used=0).model_dump_json()) diff --git a/src/main.py b/src/main.py index c11618a..9a5e83c 100644 --- a/src/main.py +++ b/src/main.py @@ -2,13 +2,13 @@ from contextlib import asynccontextmanager from fastapi import FastAPI -from src.agent import get_agent_service from src.api.external import router as ws_router +from src.agent import AgentService @asynccontextmanager async def lifespan(app: FastAPI): - get_agent_service() + AgentService() # инициализируем синглтон yield