поддержка chat_id

This commit is contained in:
Егор Кандрушин 2026-04-19 11:58:27 +03:00
parent 9cc7b45c23
commit ee192202b4
2 changed files with 7 additions and 9 deletions

View file

@ -13,11 +13,10 @@ class AgentService:
if cls._instance is None: if cls._instance is None:
cls._instance = super().__new__(cls) cls._instance = super().__new__(cls)
cls._instance._agent = create_agent() cls._instance._agent = create_agent()
cls._instance._thread_id = "default"
return cls._instance return cls._instance
async def astream(self, text: str) -> AsyncIterator[AgentEventUnion]: async def astream(self, chat_id: int, text: str) -> AsyncIterator[AgentEventUnion]:
config = {"configurable": {"thread_id": self._thread_id}} config = {"configurable": {"thread_id": chat_id}}
# Используем astream_events для перехвата детальных событий (инструменты, чанки и т.д.) # Используем astream_events для перехвата детальных событий (инструменты, чанки и т.д.)
async for event in self._agent.astream_events( async for event in self._agent.astream_events(
@ -54,7 +53,5 @@ class AgentService:
yield MsgEventEnd(tokens_used=0) # потом заменить на метадату yield MsgEventEnd(tokens_used=0) # потом заменить на метадату
def get_agent_service() -> AgentService: def get_agent_service() -> AgentService:
return AgentService() return AgentService()

View file

@ -14,9 +14,10 @@ from src.agent import get_agent_service, AgentService
router = APIRouter() router = APIRouter()
@router.websocket("/agent_ws/") @router.websocket("/agent_ws/{chat_id}/")
async def websocket_endpoint( async def websocket_endpoint(
ws: WebSocket, ws: WebSocket,
chat_id: int,
agent_service: AgentService = Depends(get_agent_service), agent_service: AgentService = Depends(get_agent_service),
): ):
await ws.accept() await ws.accept()
@ -26,7 +27,7 @@ async def websocket_endpoint(
while True: while True:
raw = await ws.receive_text() raw = await ws.receive_text()
msg = ClientMessage.validate_json(raw) msg = ClientMessage.validate_json(raw)
await process_message(ws, msg, agent_service) await process_message(ws, chat_id, msg, agent_service)
except WebSocketDisconnect: except WebSocketDisconnect:
pass pass
@ -36,9 +37,9 @@ async def websocket_endpoint(
) )
async def process_message(ws: WebSocket, msg, agent_service: AgentService): async def process_message(ws: WebSocket, chat_id: int, msg, agent_service: AgentService):
match msg: match msg:
case MsgUserMessage(): case MsgUserMessage():
async for chunk in agent_service.astream(msg.text): async for chunk in agent_service.astream(chat_id, msg.text):
await ws.send_text(chunk.model_dump_json()) await ws.send_text(chunk.model_dump_json())
await ws.send_text(MsgEventEnd(tokens_used=0).model_dump_json()) await ws.send_text(MsgEventEnd(tokens_used=0).model_dump_json())