сохранение состояние между вызовами

This commit is contained in:
Егор Кандрушин 2026-04-03 15:57:32 +03:00
parent 635986ba70
commit d4d775305f
5 changed files with 63 additions and 36 deletions

View file

@ -1,4 +1,4 @@
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Depends
from lambda_agent_api.server import (
MsgStatus,
@ -8,14 +8,17 @@ from lambda_agent_api.server import (
)
from lambda_agent_api.client import ClientMessage, MsgUserMessage
from src.agent.base import agent
from src.agent import get_agent_service, AgentService
router = APIRouter()
@router.websocket("/agent_ws/")
async def websocket_endpoint(ws: WebSocket):
async def websocket_endpoint(
ws: WebSocket,
agent_service: AgentService = Depends(get_agent_service),
):
await ws.accept()
await ws.send_text(MsgStatus().model_dump_json())
@ -23,7 +26,7 @@ async def websocket_endpoint(ws: WebSocket):
while True:
raw = await ws.receive_text()
msg = ClientMessage.validate_json(raw)
await process_message(ws, msg)
await process_message(ws, msg, agent_service)
except WebSocketDisconnect:
pass
@ -33,23 +36,9 @@ async def websocket_endpoint(ws: WebSocket):
)
async def process_message(
ws: WebSocket, msg
): # msg должно быть ClientMessage (аннотация не работает из-за TypeAdapter)
async def process_message(ws: WebSocket, msg, agent_service: AgentService):
match msg:
case MsgUserMessage():
await handle_user_message(ws, msg.text)
async def handle_user_message(ws: WebSocket, text: str):
tokens_used = 0
async for event in agent.astream({"messages": [{"role": "user", "content": text}]}):
messages = event.get("messages") or event.get("model", {}).get("messages", [])
if messages:
last_msg = messages[-1]
content = getattr(last_msg, "content", None)
if isinstance(content, str) and content.strip():
await ws.send_text(MsgEventTextChunk(text=content.strip()).model_dump_json())
await ws.send_text(MsgEventEnd(tokens_used=tokens_used).model_dump_json())
async for chunk in agent_service.astream(msg.text):
await ws.send_text(MsgEventTextChunk(text=chunk).model_dump_json())
await ws.send_text(MsgEventEnd(tokens_used=0).model_dump_json())