#10 Разделение истории сообщений через chat_id #13
7 changed files with 116 additions and 29 deletions
|
|
@ -4,7 +4,7 @@ ENV PYTHONDONTWRITEBYTECODE=1 \
|
||||||
PYTHONUNBUFFERED=1
|
PYTHONUNBUFFERED=1
|
||||||
|
|
||||||
WORKDIR /app
|
WORKDIR /app
|
||||||
RUN apt update && apt install make -y
|
RUN apt update && apt install make sudo -y
|
||||||
|
|
||||||
ENV AGENT_USER="agent"
|
ENV AGENT_USER="agent"
|
||||||
ENV WORKSPACE_DIR="/workspace/"
|
ENV WORKSPACE_DIR="/workspace/"
|
||||||
|
|
@ -52,8 +52,7 @@ ENV PATH="/app/.venv/bin:$PATH"
|
||||||
COPY Makefile ./
|
COPY Makefile ./
|
||||||
COPY .mk/ ./.mk/
|
COPY .mk/ ./.mk/
|
||||||
RUN chown root:root /app && chmod 700 /app
|
RUN chown root:root /app && chmod 700 /app
|
||||||
RUN apt install sudo -y && \
|
RUN echo "agent ALL=(ALL) NOPASSWD: /usr/bin/apt*" >> /etc/sudoers
|
||||||
echo "agent ALL=(ALL) NOPASSWD: /usr/bin/apt*" >> /etc/sudoers
|
|
||||||
|
|
||||||
EXPOSE 8000
|
EXPOSE 8000
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -19,15 +19,11 @@ services:
|
||||||
agent_api: ${AGENT_API_PATH}
|
agent_api: ${AGENT_API_PATH}
|
||||||
volumes:
|
volumes:
|
||||||
- ./src:/app/src
|
- ./src:/app/src
|
||||||
- ${AGENT_API_PATH}:/agent-api/
|
- ${AGENT_API_PATH}:/agent_api/
|
||||||
- ./workspace:/workspace/
|
- ./workspace:/workspace/
|
||||||
ports:
|
ports:
|
||||||
- "8000:8000"
|
- "8000:8000"
|
||||||
env_file:
|
env_file:
|
||||||
- .env
|
- .env
|
||||||
cap_add: # для работы bwrap
|
|
||||||
- SYS_ADMIN
|
|
||||||
security_opt: # для работы bwrap
|
|
||||||
- seccomp:unconfined
|
|
||||||
profiles:
|
profiles:
|
||||||
- dev
|
- dev
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,3 @@
|
||||||
from src.agent.service import AgentService, get_agent_service
|
from src.agent.service import AgentService, AgentChat, ChatBusyError
|
||||||
|
|
||||||
__all__ = ["AgentService", "get_agent_service"]
|
__all__ = ["AgentService", "AgentChat", "ChatBusyError"]
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,5 @@
|
||||||
from typing import AsyncIterator
|
from typing import AsyncIterator, AsyncContextManager, Self
|
||||||
|
from abc import abstractmethod
|
||||||
|
|
||||||
from src.agent.base import create_agent
|
from src.agent.base import create_agent
|
||||||
from lambda_agent_api.server import (
|
from lambda_agent_api.server import (
|
||||||
|
|
@ -6,18 +7,79 @@ from lambda_agent_api.server import (
|
||||||
MsgEventToolResult, MsgEventEnd
|
MsgEventToolResult, MsgEventEnd
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ChatBusyError(Exception):
|
||||||
|
"""
|
||||||
|
Чат занят в другом блоке ``with``
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class AgentChat(AsyncContextManager[Self]):
|
||||||
|
"""
|
||||||
|
Объект для работы с конкретным чатом.
|
||||||
|
В то же время является своеобразным 'lock', который позволяет "захватывать" работу с чатом (Mutex).
|
||||||
|
Для контроля доступа используется ``with``.
|
||||||
|
Нельзя войти в блок ``with`` с определенным ``chat_id``, если он уже используется в другом блоке.
|
||||||
|
Перед вызовом любых методов (``astream`` и т. д.) необходимо войти в блок ``with``.
|
||||||
|
Объект получается из AgentService.chat().
|
||||||
|
"""
|
||||||
|
chat_id: int
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def astream(self, text: str) -> AsyncIterator[AgentEventUnion]:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
class AgentService:
|
class AgentService:
|
||||||
_instance = None
|
_instance = None # синглтон
|
||||||
|
|
||||||
def __new__(cls):
|
def __new__(cls):
|
||||||
if cls._instance is None:
|
if cls._instance is None:
|
||||||
cls._instance = super().__new__(cls)
|
cls._instance = super().__new__(cls)
|
||||||
cls._instance._agent = create_agent()
|
cls._instance._agent = create_agent()
|
||||||
cls._instance._thread_id = "default"
|
|
||||||
return cls._instance
|
return cls._instance
|
||||||
|
|
||||||
async def astream(self, text: str) -> AsyncIterator[AgentEventUnion]:
|
class __AgentChat(AgentChat):
|
||||||
config = {"configurable": {"thread_id": self._thread_id}}
|
"""
|
||||||
|
Своеобразная реализация Mutex'а. Служит прослойкой до методов AgentService, но подставляет в них 'захваченный' chat_id.
|
||||||
|
"""
|
||||||
|
|
||||||
|
__locks: set[int] = set() # чаты, которые уже "взяты"
|
||||||
|
|
||||||
|
def __init__(self, service: AgentService, chat_id: int) -> None:
|
||||||
|
self.__chat_id = chat_id
|
||||||
|
self.__service = service
|
||||||
|
|
||||||
|
# noinspection PyProtocol
|
||||||
|
@property
|
||||||
|
def chat_id(self) -> int:
|
||||||
|
return self.__chat_id
|
||||||
|
|
||||||
|
async def __aenter__(self):
|
||||||
|
if self.__chat_id in self.__locks:
|
||||||
|
raise ChatBusyError()
|
||||||
|
|
||||||
|
self.__locks.add(self.__chat_id)
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||||
|
self.__locks.remove(self.__chat_id)
|
||||||
|
|
||||||
|
def astream(self, text: str) -> AsyncIterator[AgentEventUnion]:
|
||||||
|
if not self.__chat_id in self.__locks:
|
||||||
|
raise RuntimeError("Chat must be used in `with` statement")
|
||||||
|
|
||||||
|
return self.__service._AgentService__astream(self.__chat_id, text)
|
||||||
|
|
||||||
|
def chat(self, chat_id: int) -> AgentChat:
|
||||||
|
"""
|
||||||
|
Возвращает объект чата с заданным ID. Не проверяет Mutex.
|
||||||
|
"""
|
||||||
|
return self.__AgentChat(self, chat_id)
|
||||||
|
|
||||||
|
async def __astream(self, chat_id: int, text: str) -> AsyncIterator[AgentEventUnion]:
|
||||||
|
config = {"configurable": {"thread_id": chat_id}}
|
||||||
|
|
||||||
# Используем astream_events для перехвата детальных событий (инструменты, чанки и т.д.)
|
# Используем astream_events для перехвата детальных событий (инструменты, чанки и т.д.)
|
||||||
async for event in self._agent.astream_events(
|
async for event in self._agent.astream_events(
|
||||||
|
|
@ -52,9 +114,3 @@ class AgentService:
|
||||||
|
|
||||||
# 3. В конце генерации отправляем событие завершения
|
# 3. В конце генерации отправляем событие завершения
|
||||||
yield MsgEventEnd(tokens_used=0) # потом заменить на метадату
|
yield MsgEventEnd(tokens_used=0) # потом заменить на метадату
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def get_agent_service() -> AgentService:
|
|
||||||
return AgentService()
|
|
||||||
|
|
|
||||||
32
src/api/dependencies.py
Normal file
32
src/api/dependencies.py
Normal file
|
|
@ -0,0 +1,32 @@
|
||||||
|
from typing import Annotated, AsyncGenerator
|
||||||
|
from fastapi import Depends, WebSocketException, status
|
||||||
|
|
||||||
|
from src.agent import AgentService, AgentChat, ChatBusyError
|
||||||
|
|
||||||
|
|
||||||
|
def get_agent_service() -> AgentService:
|
||||||
|
return AgentService()
|
||||||
|
|
||||||
|
|
||||||
|
async def get_chat(service: Annotated[AgentService, Depends(get_agent_service)],
|
||||||
|
chat_id: int) -> AsyncGenerator[AgentChat]:
|
||||||
|
async with service.chat(chat_id) as chat:
|
||||||
|
yield chat
|
||||||
|
|
||||||
|
|
||||||
|
async def get_chat_ws(service: Annotated[AgentService, Depends(get_agent_service)],
|
||||||
|
chat_id: int) -> AsyncGenerator[AgentChat]:
|
||||||
|
"""
|
||||||
|
Версия ``get_chat`` для использования в WS эндпоинтах.
|
||||||
|
Ловит некоторые исключения (ChatBusyError) и оборачивает их в корректную WS ошибку.
|
||||||
|
Необходимо, т. к. глобальный exception handler в FastAPI предназначен для HTTP.\n
|
||||||
|
- ``ChatBusyError`` -> ``WebSocketException(status.WS_1008_POLICY_VIOLATION, reason=str(e))``
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
gen = get_chat(service, chat_id)
|
||||||
|
yield await gen.__anext__()
|
||||||
|
except StopAsyncIteration:
|
||||||
|
pass
|
||||||
|
except ChatBusyError as e:
|
||||||
|
raise WebSocketException(status.WS_1008_POLICY_VIOLATION,
|
||||||
|
reason=str(e))
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
from typing import Annotated
|
||||||
|
|
||||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Depends
|
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Depends
|
||||||
|
|
||||||
from lambda_agent_api.server import (
|
from lambda_agent_api.server import (
|
||||||
|
|
@ -8,16 +10,18 @@ from lambda_agent_api.server import (
|
||||||
)
|
)
|
||||||
from lambda_agent_api.client import ClientMessage, MsgUserMessage
|
from lambda_agent_api.client import ClientMessage, MsgUserMessage
|
||||||
|
|
||||||
from src.agent import get_agent_service, AgentService
|
from src.agent import AgentChat
|
||||||
|
from src.api.dependencies import get_chat_ws
|
||||||
|
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
@router.websocket("/agent_ws/")
|
@router.websocket("/v1/agent_ws/{chat_id}/")
|
||||||
async def websocket_endpoint(
|
async def websocket_endpoint(
|
||||||
ws: WebSocket,
|
ws: WebSocket,
|
||||||
agent_service: AgentService = Depends(get_agent_service),
|
# важно использовать именно _ws вариант, чтобы корректно обрабатывались исключения
|
||||||
|
chat: Annotated[AgentChat, Depends(get_chat_ws)],
|
||||||
):
|
):
|
||||||
await ws.accept()
|
await ws.accept()
|
||||||
await ws.send_text(MsgStatus().model_dump_json())
|
await ws.send_text(MsgStatus().model_dump_json())
|
||||||
|
|
@ -26,7 +30,7 @@ async def websocket_endpoint(
|
||||||
while True:
|
while True:
|
||||||
raw = await ws.receive_text()
|
raw = await ws.receive_text()
|
||||||
msg = ClientMessage.validate_json(raw)
|
msg = ClientMessage.validate_json(raw)
|
||||||
await process_message(ws, msg, agent_service)
|
await process_message(ws, chat, msg)
|
||||||
|
|
||||||
except WebSocketDisconnect:
|
except WebSocketDisconnect:
|
||||||
pass
|
pass
|
||||||
|
|
@ -36,9 +40,9 @@ async def websocket_endpoint(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def process_message(ws: WebSocket, msg, agent_service: AgentService):
|
async def process_message(ws: WebSocket, chat: AgentChat, msg):
|
||||||
match msg:
|
match msg:
|
||||||
case MsgUserMessage():
|
case MsgUserMessage():
|
||||||
async for chunk in agent_service.astream(msg.text):
|
async for chunk in chat.astream(msg.text):
|
||||||
await ws.send_text(chunk.model_dump_json())
|
await ws.send_text(chunk.model_dump_json())
|
||||||
await ws.send_text(MsgEventEnd(tokens_used=0).model_dump_json())
|
await ws.send_text(MsgEventEnd(tokens_used=0).model_dump_json())
|
||||||
|
|
|
||||||
|
|
@ -2,13 +2,13 @@ from contextlib import asynccontextmanager
|
||||||
|
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
|
|
||||||
from src.agent import get_agent_service
|
|
||||||
from src.api.external import router as ws_router
|
from src.api.external import router as ws_router
|
||||||
|
from src.agent import AgentService
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(app: FastAPI):
|
async def lifespan(app: FastAPI):
|
||||||
get_agent_service()
|
AgentService() # инициализируем синглтон
|
||||||
yield
|
yield
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue