# agent/src/api/external.py
# (source-viewer metadata: 82 lines, 3.1 KiB, Python)

from typing import Annotated
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Depends
from pydantic_core import ValidationError
from lambda_agent_api.server import (
MsgStatus,
MsgEventEnd,
MsgError,
)
from lambda_agent_api.client import ClientMessage, MsgUserMessage
from src.agent import AgentChat
from src.api.dependencies import get_chat_ws
from src.core.logger import get_logger
from src.core.correlation import (
generate_connection_id,
generate_message_id,
set_connection_id,
set_message_id,
clear_context,
)
# Router that holds the agent WebSocket endpoint defined in this module.
router = APIRouter()
# Module-level logger used by the handlers below.
logger = get_logger(__name__)


@router.websocket("/v1/agent_ws/{chat_id}/")
async def websocket_endpoint(
    ws: WebSocket,
    chat_id: str,
    # It is important to use the ``_ws`` variant of the dependency so that
    # exceptions raised while resolving the chat are handled correctly.
    chat: Annotated[AgentChat, Depends(get_chat_ws)],
):
    """Serve one agent-chat WebSocket session.

    Accepts the connection and sends an initial ``MsgStatus`` frame. It then
    reads text frames in a loop. Each frame is validated with
    ``ClientMessage.validate_json``:

    - Invalid frames get a ``BAD_REQUEST`` ``MsgError`` reply, and the loop
      continues.
    - Valid frames are handed to ``process_message``.

    A client disconnect ends the session quietly. Any other exception is
    logged and reported to the client as an ``INTERNAL_ERROR`` ``MsgError``,
    and the socket is closed with code 1011. Both steps are best effort,
    because the socket may already be unusable. The correlation context is
    always cleared on exit.

    Args:
        ws: The incoming WebSocket connection.
        chat_id: Chat identifier from the URL path; used here in log messages.
        chat: The chat resolved by ``get_chat_ws`` for this connection.
    """
    # Unique ID for this connection, stored in the correlation context.
    connection_id = generate_connection_id()
    set_connection_id(connection_id)
    # The handshake and initial status frame live inside ``try`` so that their
    # failures are logged and the ``finally`` cleanup still runs.
    try:
        await ws.accept()
        logger.info(f"WebSocket connection accepted for chat_id: {chat_id}")
        await ws.send_text(MsgStatus().model_dump_json())
        while True:
            raw = await ws.receive_text()
            # Fresh ID for every incoming message.
            message_id = generate_message_id()
            set_message_id(message_id)
            logger.trace(
                f"Received raw message: {len(raw)} characters for chat_id: {chat_id}"
            )
            try:
                msg = ClientMessage.validate_json(raw)
            except ValidationError as e:
                logger.warning(f"Invalid JSON received from chat {chat_id}: {e}")
                await ws.send_text(
                    MsgError(
                        code="BAD_REQUEST", details="Invalid message format"
                    ).model_dump_json()
                )
                continue
            await process_message(ws, chat, msg)
    except WebSocketDisconnect:
        logger.info(f"WebSocket disconnected for chat_id: {chat_id}")
    except Exception as exc:
        logger.exception("Unexpected error in websocket")
        # Reporting must not raise: the socket may already be closed or broken,
        # and a second exception here would hide the original error.
        try:
            await ws.send_text(
                MsgError(code="INTERNAL_ERROR", details=str(exc)).model_dump_json()
            )
            # 1011 is the "Internal Error" close code (RFC 6455, section 7.4.1).
            await ws.close(code=1011)
        except Exception:
            logger.debug(
                "Could not report internal error to the client; connection is unusable"
            )
    finally:
        clear_context()


async def process_message(ws: WebSocket, chat: AgentChat, msg):
match msg:
case MsgUserMessage():
logger.debug(f"Processing user message for chat {chat.chat_id} (text length: {len(msg.text)}, attachments: {len(msg.attachments) if msg.attachments else 0})")
async for chunk in chat.astream(msg.text, msg.attachments):
logger.trace(f"Sending stream chunk to chat {chat.chat_id}: {chunk.__class__.__name__}")
await ws.send_text(chunk.model_dump_json())
logger.debug(f"Finished processing user message for chat {chat.chat_id}")
await ws.send_text(MsgEventEnd(tokens_used=0).model_dump_json()) # TODO: подставить реальное потребление токенов