#5 Реализовать клиентскую часть канала общения. Исправлены все недочеты

This commit is contained in:
Ярослав Малинин 2026-04-01 10:30:49 +03:00
parent 0c19934008
commit 02467a4e12

View file

@ -1,176 +1,221 @@
import logging import logging
from typing import Callable, Optional from typing import Callable, Optional, AsyncIterator
import aiohttp import aiohttp
import asyncio import asyncio
from models import CM, SM, ClientMessage, ServerMessage from models import CM, SM, ClientMessage, ServerMessage
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class AgentException(Exception): class AgentException(Exception):
"""
Кастомное исключение для ошибок, полученных от агента.
Атрибуты:
code: Код ошибки из сообщения SM.Error
details: Детали ошибки
"""
def __init__(self, code: str, details: str): def __init__(self, code: str, details: str):
self.code = code self.code = code
self.details = details self.details = details
super().__init__(f"Agent error ({code}): {details}") super().__init__(f"Agent error ({code}): {details}")
class AgentBusyException(AgentException):
def __init__(self, details: str):
super().__init__("BUSY", details)
class AgentApi: class AgentApi:
""" def __init__(
Асинхронный клиент для взаимодействия с AI-агентом через WebSocket. self,
agent_id: str,
Класс инкапсулирует обмен сообщениями согласно контракту Pydantic url: str,
(CM для входящих, SM для исходящих сообщений). callback: Optional[Callable[[ServerMessage], None]] = None,
on_disconnect: Optional[Callable[['AgentApi'], None]] = None
Пример использования: ):
async with AgentApi("ws://localhost:8000", callback=my_callback) as agent: self.id = agent_id # ID агента для словаря
response = await agent.send_message("Hello, agent!")
async for chunk in response:
match chunk:
case SM.EventTextChunk():
print(chunk.text, end="")
print(f" [{response.tokens} токенов]")
Атрибуты:
url: URL WebSocket сервера агента
callback: Функция обратного вызова для событий не в процессе генерации
_session: aiohttp ClientSession для WebSocket соединения
_ws: Активное WebSocket соединение
_connected: Флаг состояния соединения
_queue: Очередь для передачи чанков ответа
_listen_task: Задача для прослушивания WS
"""
def __init__(self, url: str, callback: Optional[Callable[[ServerMessage], None]] = None):
"""
Инициализирует клиент агента.
Аргументы:
url: URL WebSocket сервера (например, "ws://localhost:8000")
callback: Функция обратного вызова для событий не в процессе генерации ответа
"""
self.url = url self.url = url
self.callback = callback self.callback = callback
self.on_disconnect = on_disconnect
self._session: aiohttp.ClientSession | None = None self._session: aiohttp.ClientSession | None = None
self._ws: aiohttp.ClientWebSocketResponse | None = None self._ws: aiohttp.ClientWebSocketResponse | None = None
self._connected = False self._connected = False
self._queue = asyncio.Queue()
self._current_queue: asyncio.Queue | None = None
self._request_lock = asyncio.Lock()
self._listen_task: asyncio.Task | None = None self._listen_task: asyncio.Task | None = None
async def __aenter__(self): async def connect(self):
""" """Явное подключение к агенту."""
Входит в контекстный менеджер, устанавливает WebSocket соединение.
Возвращает:
self: Экземпляр AgentApi
Raises:
RuntimeError: Если не удалось установить соединение
"""
self._session = aiohttp.ClientSession() self._session = aiohttp.ClientSession()
try: try:
self._ws = await self._session.ws_connect(self.url, heartbeat=30) self._ws = await self._session.ws_connect(self.url, heartbeat=30)
self._connected = True
logger.info(f"Connected to agent at {self.url}")
# Ожидаем SM.Status при открытии соединения # Используем asyncio.wait_for вместо timeout в receive
msg = await self._ws.receive() # здесь отлавливаем исключение
msg = await asyncio.wait_for(self._ws.receive(), timeout=5.0)
if msg.type == aiohttp.WSMsgType.TEXT: if msg.type == aiohttp.WSMsgType.TEXT:
status_msg = ServerMessage.model_validate_json(msg.data) status_msg = ServerMessage.model_validate_json(msg.data)
if isinstance(status_msg, SM.Status): if isinstance(status_msg, SM.Status):
if self.callback: logger.info(f"Agent {self.id} is ready")
self.callback(status_msg)
logger.info("Agent is ready to accept messages")
else: else:
raise RuntimeError( raise AgentException(
f"Expected SM.Status on connection, got {status_msg.type}") "INVALID_STATUS", f"Expected SM.Status, got {status_msg.type}")
else: else:
raise RuntimeError( raise AgentException("UNEXPECTED_MSG_TYPE",
f"Unexpected message type on connection: {msg.type}") f"Unexpected message type: {msg.type}")
# Запускаем задачу для прослушивания WS self._connected = True
self._listen_task = asyncio.create_task(self._listen()) self._listen_task = asyncio.create_task(self._listen())
except asyncio.TimeoutError as e:
# Специфичная обработка таймаута
if self._ws and not self._ws.closed:
await self._ws.close()
if self._session and not self._session.closed:
await self._session.close()
raise AgentException(
"TIMEOUT", "Agent did not send initial Status message within 5 seconds") from e
except AgentException:
# Перехватываем наши собственные ошибки (INVALID_STATUS, UNEXPECTED_MSG_TYPE),
# закрываем ресурсы и пробрасываем ошибку дальше без изменений!
if self._ws and not self._ws.closed:
await self._ws.close()
if self._session and not self._session.closed:
await self._session.close()
raise
except Exception as e: except Exception as e:
await self._session.close() # Обработка всех остальных ошибок (например, aiohttp.ClientConnectionError)
raise RuntimeError(f"Failed to connect to agent: {e}") from e if self._ws and not self._ws.closed:
await self._ws.close()
if self._session and not self._session.closed:
await self._session.close()
return self # Можно оставить RuntimeError, а можно тоже завернуть в AgentException
raise AgentException(code="CONNECTION_ERROR",
async def __aexit__(self, exc_type, exc_val, exc_tb): details=f"Failed to connect agent {self.id}: {e}") from e
"""
Выходит из контекстного менеджера, закрывает WebSocket соединение.
Аргументы:
exc_type: Тип исключения, если оно произошло
exc_val: Значение исключения
exc_tb: Traceback исключения
"""
await self.close()
async def close(self): async def close(self):
""" """Явное ручное закрытие соединения."""
Закрывает WebSocket соединение и завершает сессию.
"""
self._connected = False
if self._listen_task and not self._listen_task.done(): if self._listen_task and not self._listen_task.done():
# Это вызовет CancelledError внутри _listen и запустит finally
self._listen_task.cancel() self._listen_task.cancel()
try: try:
await self._listen_task await self._listen_task
except asyncio.CancelledError: except asyncio.CancelledError:
pass pass
else:
await self._cleanup()
async def _cleanup(self):
logger.info(f"Cleaning up agent {self.id}")
self._connected = False # Сбрасываем в любом случае
if self._current_queue:
await self._current_queue.put(ConnectionError("Connection closed"))
if self._ws and not self._ws.closed: if self._ws and not self._ws.closed:
await self._ws.close() await self._ws.close()
logger.info("WebSocket connection closed")
if self._session: if self._session and not self._session.closed:
await self._session.close() await self._session.close()
logger.info("Client session closed") # aiohttp рекомендует небольшую паузу для корректного закрытия TCP соединений под капотом
await asyncio.sleep(0.0)
async def send_message(self, text: str) -> 'ResponseIterator': if self.on_disconnect:
try:
# Важно: callback синхронный, он не должен делать тяжелых IO операций
self.on_disconnect(self)
self.on_disconnect = None # Защита от двойного вызова
except Exception as e:
logger.error(f"Error in on_disconnect: {e}")
async def send_message(self, text: str) -> AsyncIterator[SM.AgentEvent]:
""" """
Отправляет сообщение от пользователя на сервер и возвращает итератор для получения ответа. Нативный асинхронный генератор.
Не требует отдельного класса ResponseIterator.
Аргументы: Гарантированно освобождает блокировку.
text: Текст сообщения пользователя
Возвращает:
ResponseIterator: Объект для асинхронной итерации по чанкам ответа
Raises:
RuntimeError: Если соединение не установлено
aiohttp.ClientError: Если не удалось отправить сообщение
""" """
if not self._connected or not self._ws: if not self._connected or not self._ws:
raise RuntimeError( raise AgentException(code="NOT_CONNECTED",
"Not connected to agent. Use 'async with' context manager.") details="Not connected. Call connect() first.")
message = CM.UserMessage( if self._request_lock.locked():
type=CM.Type.USER_MESSAGE, raise AgentBusyException(
text=text "Agent is currently processing another request")
)
# Блокируем параллельные запросы
# если идет стриминг ответа, то при попытки отправить новое сообщение будет ошибка - ее рейзим(делаем AgentBusyError)
# храню состояние занят ли агент или нет self.is_busy - если занят, то кидаем исключение (состояние занятости агента хранится lock)
await self._request_lock.acquire()
try: try:
await self._ws.send_str(message.model_dump_json()) self._current_queue = asyncio.Queue()
logger.debug(f"Sent user message: {text[:100]}...")
except Exception as e:
self._connected = False
raise aiohttp.ClientError(f"Failed to send message: {e}") from e
return ResponseIterator(self._queue) message = CM.UserMessage(
type=CM.Type.USER_MESSAGE,
text=text
)
await self._ws.send_str(message.model_dump_json())
logger.debug(f"[{self.id}] Sent message: {text[:50]}...")
# Читаем ответы из очереди
while True:
chunk = await self._current_queue.get()
# Если прилетела ошибка (от агента или из _listen)
if isinstance(chunk, Exception):
raise chunk
# Если конец ответа
if isinstance(chunk, SM.EventEnd):
# Если вам нужны tokens_used, можно сохранить их в атрибут self
break
yield chunk
finally:
# Если наш сервер (бэкенд агента) поддерживает команду отмены (например, специальный CM.Type.STOP_MESSAGE), то блок finally должен выглядеть так:
# if self._ws and not self._ws.closed:
# # Отправляем серверу команду "Хватит генерировать!"
# stop_msg = CM.UserMessage(type=CM.Type.STOP_GENERATION, text="")
# # (используйте правильную структуру вашего CM)
# # Запускаем отправку как таску, чтобы не блокировать finally
# asyncio.create_task(self._ws.send_str(stop_msg.model_dump_json()))
if self._current_queue:
# 1. Сохраняем ссылку на локальную переменную и отвязываем от класса.
# Это гарантирует, что _listen больше не положит сюда новые сообщения.
orphan_queue = self._current_queue
self._current_queue = None
# 2. Вычищаем (drain) очередь от сообщений, которые клиент не успел забрать
while not orphan_queue.empty():
try:
orphan_msg = orphan_queue.get_nowait()
# Если это исключение, просто логируем, в коллбек его кидать не стоит
if isinstance(orphan_msg, Exception):
logger.debug(
f"[{self.id}] Dropped exception from queue during cleanup: {orphan_msg}")
continue
# 3. Отправляем "мусорные/осиротевшие" куски в callback
if self.callback:
self.callback(orphan_msg)
else:
logger.debug(
f"[{self.id}] Dropped orphaned message during cleanup")
except asyncio.QueueEmpty:
break
# Освобождаем лок, чтобы разрешить новые запросы
if self._request_lock.locked():
self._request_lock.release()
async def _listen(self): async def _listen(self):
""" """"
Прослушивает WebSocket в фоне и обрабатывает сообщения. Прослушивание вебсокета.
""" """
try: try:
async for msg in self._ws: async for msg in self._ws:
@ -178,73 +223,63 @@ class AgentApi:
try: try:
outgoing_msg = ServerMessage.model_validate_json( outgoing_msg = ServerMessage.model_validate_json(
msg.data) msg.data)
logger.debug(
f"Received message of type: {outgoing_msg.type}")
if isinstance(outgoing_msg, SM.AgentEvent): if isinstance(outgoing_msg, SM.AgentEvent):
await self._queue.put(outgoing_msg) if self._current_queue:
elif isinstance(outgoing_msg, SM.Status): await self._current_queue.put(outgoing_msg)
if self.callback: # Если очереди нет (клиент отменил запрос), но токены идут — шлем их в коллбек
elif self.callback:
self.callback(outgoing_msg) self.callback(outgoing_msg)
logger.info("Agent status update") else:
logger.warning(
f"[{self.id}] AgentEvent without active request")
elif isinstance(outgoing_msg, SM.EventEnd):
if self._current_queue:
await self._current_queue.put(outgoing_msg)
elif isinstance(outgoing_msg, SM.Error): elif isinstance(outgoing_msg, SM.Error):
if self.callback: if self.callback:
self.callback(outgoing_msg) self.callback(outgoing_msg)
error = AgentException( error = AgentException(
outgoing_msg.code, outgoing_msg.details) outgoing_msg.code, outgoing_msg.details)
logger.error(f"Agent error: {error}") logger.error(f"[{self.id}] Agent error: {error}")
# Не бросаем исключение, а вызываем callback if self._current_queue:
await self._current_queue.put(error)
elif isinstance(outgoing_msg, SM.GracefulDisconnect): elif isinstance(outgoing_msg, SM.GracefulDisconnect):
if self.callback: if self.callback:
self.callback(outgoing_msg) self.callback(outgoing_msg)
logger.info("Agent gracefully disconnecting") logger.info(
self._connected = False f"[{self.id}] Gracefully disconnecting")
break break # Выход из цикла приведет к finally -> _cleanup
else: else:
# Другие типы через callback # написать лог о неизвестном сообщении -> вывод неожиданного сообщения в лог
logger.warning(
f"[{self.id}] Unknown message type: {outgoing_msg.type}"
)
if self.callback: if self.callback:
self.callback(outgoing_msg) self.callback(outgoing_msg)
except Exception as e: except Exception as e:
logger.error(f"Failed to deserialize message: {e}") logger.error(
if self.callback: f"[{self.id}] Failed to deserialize message: {e}")
# Можно создать специальное сообщение об ошибке if self._current_queue:
pass await self._current_queue.put(
AgentException(
"PARSE_ERROR", f"Validation failed: {e}")
)
elif msg.type == aiohttp.WSMsgType.ERROR: elif msg.type in (aiohttp.WSMsgType.ERROR, aiohttp.WSMsgType.CLOSED):
error_msg = f"WebSocket error: {self._ws.exception()}" logger.error(
logger.error(error_msg) f"[{self.id}] WebSocket closed/error: {msg.type}")
self._connected = False
break break
elif msg.type == aiohttp.WSMsgType.CLOSED:
logger.info("WebSocket connection closed by server")
self._connected = False
break
except Exception as e:
logger.error(f"Error in listen loop: {e}")
self._connected = False
class ResponseIterator:
"""
Асинхронный итератор для получения чанков ответа от агента.
"""
def __init__(self, queue: asyncio.Queue):
self._queue = queue
self.tokens = 0
def __aiter__(self):
return self
async def __anext__(self):
try:
chunk = await self._queue.get()
if isinstance(chunk, SM.EventEnd):
self.tokens = chunk.tokens_used
raise StopAsyncIteration
return chunk
except asyncio.CancelledError: except asyncio.CancelledError:
raise StopAsyncIteration pass # Нормальное прерывание при self.close()
except Exception as e:
logger.error(f"[{self.id}] Error in listen loop: {e}")
finally:
# Как только цикл прослушивания закончился, запускаем процедуру очистки
await self._cleanup()