82 lines
3.1 KiB
Python
82 lines
3.1 KiB
Python
from typing import Annotated
|
|
|
|
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Depends
|
|
from pydantic_core import ValidationError
|
|
|
|
from lambda_agent_api.server import (
|
|
MsgStatus,
|
|
MsgEventEnd,
|
|
MsgError,
|
|
)
|
|
from lambda_agent_api.client import ClientMessage, MsgUserMessage
|
|
|
|
from src.agent import AgentChat
|
|
from src.api.dependencies import get_chat_ws
|
|
from src.core.logger import get_logger
|
|
from src.core.correlation import (
|
|
generate_connection_id,
|
|
generate_message_id,
|
|
set_connection_id,
|
|
set_message_id,
|
|
clear_context,
|
|
)
|
|
|
|
|
|
# Router that the WebSocket endpoint in this module is registered on.
router = APIRouter()
|
|
|
|
# Module-level logger obtained from the project's logging helper
# (presumably enriched with the correlation IDs set below -- confirm in src.core.logger).
logger = get_logger(__name__)
|
|
|
|
|
|
@router.websocket("/v1/agent_ws/{chat_id}/")
async def websocket_endpoint(
    ws: WebSocket,
    chat_id: str,
    # It is important to use the ``_ws`` dependency variant here so that
    # exceptions are handled correctly.
    chat: Annotated[AgentChat, Depends(get_chat_ws)],
) -> None:
    """Serve one agent chat session over a WebSocket.

    Accepts the connection, sends an initial ``MsgStatus`` frame and then
    loops: each received text frame is validated as a client message and
    handed to ``process_message``. Frames that fail validation are answered
    with a ``BAD_REQUEST`` error and the loop continues.

    Args:
        ws: The WebSocket connection for this client.
        chat_id: Chat identifier taken from the URL path (used for logging).
        chat: Chat object resolved by the ``get_chat_ws`` dependency.

    The correlation context set for this connection is always cleared on
    exit, including when the handshake or the initial status frame fails.
    """
    # Generate a unique ID for this connection.
    connection_id = generate_connection_id()
    set_connection_id(connection_id)

    try:
        await ws.accept()
        # Logged only after the handshake has actually succeeded.
        logger.info(f"WebSocket connection accepted for chat_id: {chat_id}")
        await ws.send_text(MsgStatus().model_dump_json())

        try:
            while True:
                raw = await ws.receive_text()

                # Generate an ID for every incoming message.
                message_id = generate_message_id()
                set_message_id(message_id)

                logger.trace(f"Received raw message: {len(raw)} characters for chat_id: {chat_id}")
                try:
                    msg = ClientMessage.validate_json(raw)
                except ValidationError as e:
                    logger.warning(f"Invalid JSON received from chat {chat_id}: {e}")
                    await ws.send_text(MsgError(code="BAD_REQUEST", details="Invalid message format").model_dump_json())
                    continue
                await process_message(ws, chat, msg)

        except WebSocketDisconnect:
            logger.info(f"WebSocket disconnected for chat_id: {chat_id}")
        except Exception as exc:
            logger.exception("Unexpected error in websocket")
            # The socket may already be closed (the failure itself can be a
            # broken connection); a failed send must not raise out of this
            # handler and hide the error logged above.
            try:
                await ws.send_text(
                    MsgError(code="INTERNAL_ERROR", details=str(exc)).model_dump_json()
                )
            except (WebSocketDisconnect, RuntimeError):
                logger.debug(f"Could not deliver error frame for chat_id: {chat_id}; connection already closed")
    finally:
        clear_context()
|
|
|
|
|
|
async def process_message(ws: WebSocket, chat: AgentChat, msg):
|
|
match msg:
|
|
case MsgUserMessage():
|
|
logger.debug(f"Processing user message for chat {chat.chat_id} (text length: {len(msg.text)}, attachments: {len(msg.attachments) if msg.attachments else 0})")
|
|
async for chunk in chat.astream(msg.text, msg.attachments):
|
|
logger.trace(f"Sending stream chunk to chat {chat.chat_id}: {chunk.__class__.__name__}")
|
|
await ws.send_text(chunk.model_dump_json())
|
|
logger.debug(f"Finished processing user message for chat {chat.chat_id}")
|
|
await ws.send_text(MsgEventEnd(tokens_used=0).model_dump_json()) # TODO: подставить реальное потребление токенов
|