from typing import Annotated from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Depends from pydantic_core import ValidationError from lambda_agent_api.server import ( MsgStatus, MsgEventEnd, MsgError, ) from lambda_agent_api.client import ClientMessage, MsgUserMessage from src.agent import AgentChat from src.api.dependencies import get_chat_ws from src.core.logger import get_logger from src.core.correlation import ( generate_connection_id, generate_message_id, set_connection_id, set_message_id, clear_context, ) router = APIRouter() logger = get_logger(__name__) @router.websocket("/v1/agent_ws/{chat_id}/") async def websocket_endpoint( ws: WebSocket, chat_id: str, # важно использовать именно _ws вариант, чтобы корректно обрабатывались исключения chat: Annotated[AgentChat, Depends(get_chat_ws)], ): # Генерируем уникальный ID для этого подключения connection_id = generate_connection_id() set_connection_id(connection_id) logger.info(f"WebSocket connection accepted for chat_id: {chat_id}") await ws.accept() await ws.send_text(MsgStatus().model_dump_json()) try: while True: raw = await ws.receive_text() # Генерируем ID для каждого сообщения message_id = generate_message_id() set_message_id(message_id) logger.trace(f"Received raw message: {len(raw)} characters for chat_id: {chat_id}") try: msg = ClientMessage.validate_json(raw) except ValidationError as e: logger.warning(f"Invalid JSON received from chat {chat_id}: {e}") await ws.send_text(MsgError(code="BAD_REQUEST", details="Invalid message format").model_dump_json()) continue await process_message(ws, chat, msg) except WebSocketDisconnect: logger.info(f"WebSocket disconnected for chat_id: {chat_id}") pass except Exception as exc: logger.exception("Unexpected error in websocket") await ws.send_text( MsgError(code="INTERNAL_ERROR", details=str(exc)).model_dump_json() ) finally: clear_context() async def process_message(ws: WebSocket, chat: AgentChat, msg): match msg: case MsgUserMessage(): logger.debug(f"Processing user message for chat {chat.chat_id} (text length: {len(msg.text)}, attachments: {len(msg.attachments) if msg.attachments else 0})") async for chunk in chat.astream(msg.text, msg.attachments): logger.trace(f"Sending stream chunk to chat {chat.chat_id}: {chunk.__class__.__name__}") await ws.send_text(chunk.model_dump_json()) logger.debug(f"Finished processing user message for chat {chat.chat_id}") await ws.send_text(MsgEventEnd(tokens_used=0).model_dump_json()) # TODO: подставить реальное потребление токенов