From ee192202b4f1fe6e7c479aff88f1a2da031325ce Mon Sep 17 00:00:00 2001 From: MrKan Date: Sun, 19 Apr 2026 11:58:27 +0300 Subject: [PATCH] =?UTF-8?q?=D0=BF=D0=BE=D0=B4=D0=B4=D0=B5=D1=80=D0=B6?= =?UTF-8?q?=D0=BA=D0=B0=20chat=5Fid?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/agent/service.py | 7 ++----- src/api/external.py | 9 +++++---- 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/src/agent/service.py b/src/agent/service.py index 05f50c4..4cdae6d 100644 --- a/src/agent/service.py +++ b/src/agent/service.py @@ -13,11 +13,10 @@ class AgentService: if cls._instance is None: cls._instance = super().__new__(cls) cls._instance._agent = create_agent() - cls._instance._thread_id = "default" return cls._instance - async def astream(self, text: str) -> AsyncIterator[AgentEventUnion]: - config = {"configurable": {"thread_id": self._thread_id}} + async def astream(self, chat_id: int, text: str) -> AsyncIterator[AgentEventUnion]: + config = {"configurable": {"thread_id": chat_id}} # Используем astream_events для перехвата детальных событий (инструменты, чанки и т.д.) async for event in self._agent.astream_events( @@ -54,7 +53,5 @@ class AgentService: yield MsgEventEnd(tokens_used=0) # потом заменить на метадату - - def get_agent_service() -> AgentService: return AgentService() diff --git a/src/api/external.py b/src/api/external.py index c93f2cb..9a84391 100644 --- a/src/api/external.py +++ b/src/api/external.py @@ -14,9 +14,10 @@ from src.agent import get_agent_service, AgentService router = APIRouter() -@router.websocket("/agent_ws/") +@router.websocket("/agent_ws/{chat_id}/") async def websocket_endpoint( ws: WebSocket, + chat_id: int, agent_service: AgentService = Depends(get_agent_service), ): await ws.accept() @@ -26,7 +27,7 @@ async def websocket_endpoint( while True: raw = await ws.receive_text() msg = ClientMessage.validate_json(raw) - await process_message(ws, msg, agent_service) + await process_message(ws, chat_id, msg, agent_service) except WebSocketDisconnect: pass @@ -36,9 +37,9 @@ async def websocket_endpoint( ) -async def process_message(ws: WebSocket, msg, agent_service: AgentService): +async def process_message(ws: WebSocket, chat_id: int, msg, agent_service: AgentService): match msg: case MsgUserMessage(): - async for chunk in agent_service.astream(msg.text): + async for chunk in agent_service.astream(chat_id, msg.text): await ws.send_text(chunk.model_dump_json()) await ws.send_text(MsgEventEnd(tokens_used=0).model_dump_json())