поддержка chat_id
This commit is contained in:
parent
9cc7b45c23
commit
ee192202b4
2 changed files with 7 additions and 9 deletions
|
|
@ -13,11 +13,10 @@ class AgentService:
|
||||||
if cls._instance is None:
|
if cls._instance is None:
|
||||||
cls._instance = super().__new__(cls)
|
cls._instance = super().__new__(cls)
|
||||||
cls._instance._agent = create_agent()
|
cls._instance._agent = create_agent()
|
||||||
cls._instance._thread_id = "default"
|
|
||||||
return cls._instance
|
return cls._instance
|
||||||
|
|
||||||
async def astream(self, text: str) -> AsyncIterator[AgentEventUnion]:
|
async def astream(self, chat_id: int, text: str) -> AsyncIterator[AgentEventUnion]:
|
||||||
config = {"configurable": {"thread_id": self._thread_id}}
|
config = {"configurable": {"thread_id": chat_id}}
|
||||||
|
|
||||||
# Используем astream_events для перехвата детальных событий (инструменты, чанки и т.д.)
|
# Используем astream_events для перехвата детальных событий (инструменты, чанки и т.д.)
|
||||||
async for event in self._agent.astream_events(
|
async for event in self._agent.astream_events(
|
||||||
|
|
@ -54,7 +53,5 @@ class AgentService:
|
||||||
yield MsgEventEnd(tokens_used=0) # потом заменить на метадату
|
yield MsgEventEnd(tokens_used=0) # потом заменить на метадату
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def get_agent_service() -> AgentService:
|
def get_agent_service() -> AgentService:
|
||||||
return AgentService()
|
return AgentService()
|
||||||
|
|
|
||||||
|
|
@ -14,9 +14,10 @@ from src.agent import get_agent_service, AgentService
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
@router.websocket("/agent_ws/")
|
@router.websocket("/agent_ws/{chat_id}/")
|
||||||
async def websocket_endpoint(
|
async def websocket_endpoint(
|
||||||
ws: WebSocket,
|
ws: WebSocket,
|
||||||
|
chat_id: int,
|
||||||
agent_service: AgentService = Depends(get_agent_service),
|
agent_service: AgentService = Depends(get_agent_service),
|
||||||
):
|
):
|
||||||
await ws.accept()
|
await ws.accept()
|
||||||
|
|
@ -26,7 +27,7 @@ async def websocket_endpoint(
|
||||||
while True:
|
while True:
|
||||||
raw = await ws.receive_text()
|
raw = await ws.receive_text()
|
||||||
msg = ClientMessage.validate_json(raw)
|
msg = ClientMessage.validate_json(raw)
|
||||||
await process_message(ws, msg, agent_service)
|
await process_message(ws, chat_id, msg, agent_service)
|
||||||
|
|
||||||
except WebSocketDisconnect:
|
except WebSocketDisconnect:
|
||||||
pass
|
pass
|
||||||
|
|
@ -36,9 +37,9 @@ async def websocket_endpoint(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def process_message(ws: WebSocket, msg, agent_service: AgentService):
|
async def process_message(ws: WebSocket, chat_id: int, msg, agent_service: AgentService):
|
||||||
match msg:
|
match msg:
|
||||||
case MsgUserMessage():
|
case MsgUserMessage():
|
||||||
async for chunk in agent_service.astream(msg.text):
|
async for chunk in agent_service.astream(chat_id, msg.text):
|
||||||
await ws.send_text(chunk.model_dump_json())
|
await ws.send_text(chunk.model_dump_json())
|
||||||
await ws.send_text(MsgEventEnd(tokens_used=0).model_dump_json())
|
await ws.send_text(MsgEventEnd(tokens_used=0).model_dump_json())
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue