отдельная сущность AgentChat для разделения чатов. Условный "семафор", гарантирующий, что с чатом может работать только одно подключение
This commit is contained in:
parent
ee192202b4
commit
69ec28037a
5 changed files with 72 additions and 15 deletions
|
|
@ -1,3 +1,3 @@
|
|||
from src.agent.service import AgentService, get_agent_service
|
||||
from src.agent.service import AgentService, AgentChat
|
||||
|
||||
__all__ = ["AgentService", "get_agent_service"]
|
||||
__all__ = ["AgentService", "AgentChat"]
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
from typing import AsyncIterator
|
||||
from typing import AsyncIterator, AsyncContextManager, Self
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from src.agent.base import create_agent
|
||||
from lambda_agent_api.server import (
|
||||
|
|
@ -6,6 +7,19 @@ from lambda_agent_api.server import (
|
|||
MsgEventToolResult, MsgEventEnd
|
||||
)
|
||||
|
||||
|
||||
class ChatInUseError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class AgentChat(AsyncContextManager[Self]):
|
||||
chat_id: int
|
||||
|
||||
@abstractmethod
|
||||
def astream(self, text: str) -> AsyncIterator[AgentEventUnion]:
|
||||
...
|
||||
|
||||
|
||||
class AgentService:
|
||||
_instance = None
|
||||
|
||||
|
|
@ -15,7 +29,38 @@ class AgentService:
|
|||
cls._instance._agent = create_agent()
|
||||
return cls._instance
|
||||
|
||||
async def astream(self, chat_id: int, text: str) -> AsyncIterator[AgentEventUnion]:
|
||||
class __AgentChat(AgentChat):
|
||||
__locks: set[int] = set()
|
||||
|
||||
def __init__(self, service: AgentService, chat_id: int) -> None:
|
||||
self.__chat_id = chat_id
|
||||
self.__service = service
|
||||
|
||||
# noinspection PyProtocol
|
||||
@property
|
||||
def chat_id(self) -> int:
|
||||
return self.__chat_id
|
||||
|
||||
async def __aenter__(self):
|
||||
if self.__chat_id in self.__locks:
|
||||
raise ChatInUseError()
|
||||
|
||||
self.__locks.add(self.__chat_id)
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
self.__locks.remove(self.__chat_id)
|
||||
|
||||
async def astream(self, text: str) -> AsyncIterator[AgentEventUnion]:
|
||||
if not self.__chat_id in self.__locks:
|
||||
raise RuntimeError("Chat must be used in `with` statement")
|
||||
|
||||
return self.__service._AgentService__astream(self.__chat_id, text)
|
||||
|
||||
def chat(self, chat_id: int) -> AgentChat:
|
||||
return self.__AgentChat(self, chat_id)
|
||||
|
||||
async def __astream(self, chat_id: int, text: str) -> AsyncIterator[AgentEventUnion]:
|
||||
config = {"configurable": {"thread_id": chat_id}}
|
||||
|
||||
# Используем astream_events для перехвата детальных событий (инструменты, чанки и т.д.)
|
||||
|
|
@ -51,7 +96,3 @@ class AgentService:
|
|||
|
||||
# 3. В конце генерации отправляем событие завершения
|
||||
yield MsgEventEnd(tokens_used=0) # потом заменить на метадату
|
||||
|
||||
|
||||
def get_agent_service() -> AgentService:
|
||||
return AgentService()
|
||||
|
|
|
|||
14
src/api/dependencies.py
Normal file
14
src/api/dependencies.py
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
from typing import Annotated, AsyncGenerator
|
||||
from fastapi import Depends
|
||||
|
||||
from src.agent import AgentService, AgentChat
|
||||
|
||||
|
||||
def get_agent_service() -> AgentService:
|
||||
return AgentService()
|
||||
|
||||
|
||||
async def get_chat(service: Annotated[AgentService, Depends(get_agent_service)],
|
||||
chat_id: int) -> AsyncGenerator[AgentChat]:
|
||||
async with service.chat(chat_id) as chat:
|
||||
yield chat
|
||||
|
|
@ -1,3 +1,5 @@
|
|||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Depends
|
||||
|
||||
from lambda_agent_api.server import (
|
||||
|
|
@ -8,7 +10,8 @@ from lambda_agent_api.server import (
|
|||
)
|
||||
from lambda_agent_api.client import ClientMessage, MsgUserMessage
|
||||
|
||||
from src.agent import get_agent_service, AgentService
|
||||
from src.agent import AgentChat
|
||||
from src.api.dependencies import get_chat
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
|
|
@ -17,8 +20,7 @@ router = APIRouter()
|
|||
@router.websocket("/agent_ws/{chat_id}/")
|
||||
async def websocket_endpoint(
|
||||
ws: WebSocket,
|
||||
chat_id: int,
|
||||
agent_service: AgentService = Depends(get_agent_service),
|
||||
chat: Annotated[AgentChat, Depends(get_chat)],
|
||||
):
|
||||
await ws.accept()
|
||||
await ws.send_text(MsgStatus().model_dump_json())
|
||||
|
|
@ -27,7 +29,7 @@ async def websocket_endpoint(
|
|||
while True:
|
||||
raw = await ws.receive_text()
|
||||
msg = ClientMessage.validate_json(raw)
|
||||
await process_message(ws, chat_id, msg, agent_service)
|
||||
await process_message(ws, chat, msg)
|
||||
|
||||
except WebSocketDisconnect:
|
||||
pass
|
||||
|
|
@ -37,9 +39,9 @@ async def websocket_endpoint(
|
|||
)
|
||||
|
||||
|
||||
async def process_message(ws: WebSocket, chat_id: int, msg, agent_service: AgentService):
|
||||
async def process_message(ws: WebSocket, chat: AgentChat, msg):
|
||||
match msg:
|
||||
case MsgUserMessage():
|
||||
async for chunk in agent_service.astream(chat_id, msg.text):
|
||||
async for chunk in chat.astream(msg.text):
|
||||
await ws.send_text(chunk.model_dump_json())
|
||||
await ws.send_text(MsgEventEnd(tokens_used=0).model_dump_json())
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ from contextlib import asynccontextmanager
|
|||
|
||||
from fastapi import FastAPI
|
||||
|
||||
from src.agent import get_agent_service
|
||||
from src.api.dependencies import get_agent_service
|
||||
from src.api.external import router as ws_router
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue