diff --git a/api/agent_api.py b/api/agent_api.py index 91686a1..21a30d2 100644 --- a/api/agent_api.py +++ b/api/agent_api.py @@ -1,176 +1,221 @@ import logging -from typing import Callable, Optional +from typing import Callable, Optional, AsyncIterator import aiohttp import asyncio from models import CM, SM, ClientMessage, ServerMessage - logger = logging.getLogger(__name__) class AgentException(Exception): - """ - Кастомное исключение для ошибок, полученных от агента. - - Атрибуты: - code: Код ошибки из сообщения SM.Error - details: Детали ошибки - """ - def __init__(self, code: str, details: str): self.code = code self.details = details super().__init__(f"Agent error ({code}): {details}") +class AgentBusyException(AgentException): + def __init__(self, details: str): + super().__init__("BUSY", details) + + class AgentApi: - """ - Асинхронный клиент для взаимодействия с AI-агентом через WebSocket. - - Класс инкапсулирует обмен сообщениями согласно контракту Pydantic - (CM для входящих, SM для исходящих сообщений). - - Пример использования: - async with AgentApi("ws://localhost:8000", callback=my_callback) as agent: - response = await agent.send_message("Hello, agent!") - async for chunk in response: - match chunk: - case SM.EventTextChunk(): - print(chunk.text, end="") - print(f" [{response.tokens} токенов]") - - Атрибуты: - url: URL WebSocket сервера агента - callback: Функция обратного вызова для событий не в процессе генерации - _session: aiohttp ClientSession для WebSocket соединения - _ws: Активное WebSocket соединение - _connected: Флаг состояния соединения - _queue: Очередь для передачи чанков ответа - _listen_task: Задача для прослушивания WS - """ - - def __init__(self, url: str, callback: Optional[Callable[[ServerMessage], None]] = None): - """ - Инициализирует клиент агента. - - Аргументы: - url: URL WebSocket сервера (например, "ws://localhost:8000") - callback: Функция обратного вызова для событий не в процессе генерации ответа - """ + def __init__( + self, + agent_id: str, + url: str, + callback: Optional[Callable[[ServerMessage], None]] = None, + on_disconnect: Optional[Callable[['AgentApi'], None]] = None + ): + self.id = agent_id # ID агента для словаря self.url = url self.callback = callback + self.on_disconnect = on_disconnect + self._session: aiohttp.ClientSession | None = None self._ws: aiohttp.ClientWebSocketResponse | None = None self._connected = False - self._queue = asyncio.Queue() + + self._current_queue: asyncio.Queue | None = None + self._request_lock = asyncio.Lock() self._listen_task: asyncio.Task | None = None - async def __aenter__(self): - """ - Входит в контекстный менеджер, устанавливает WebSocket соединение. - - Возвращает: - self: Экземпляр AgentApi - - Raises: - RuntimeError: Если не удалось установить соединение - """ + async def connect(self): + """Явное подключение к агенту.""" self._session = aiohttp.ClientSession() try: self._ws = await self._session.ws_connect(self.url, heartbeat=30) - self._connected = True - logger.info(f"Connected to agent at {self.url}") - # Ожидаем SM.Status при открытии соединения - msg = await self._ws.receive() + # Используем asyncio.wait_for вместо timeout в receive + # здесь отлавливаем исключение + msg = await asyncio.wait_for(self._ws.receive(), timeout=5.0) + if msg.type == aiohttp.WSMsgType.TEXT: status_msg = ServerMessage.model_validate_json(msg.data) if isinstance(status_msg, SM.Status): - if self.callback: - self.callback(status_msg) - logger.info("Agent is ready to accept messages") + logger.info(f"Agent {self.id} is ready") else: - raise RuntimeError( - f"Expected SM.Status on connection, got {status_msg.type}") + raise AgentException( + "INVALID_STATUS", f"Expected SM.Status, got {status_msg.type}") else: - raise RuntimeError( - f"Unexpected message type on connection: {msg.type}") + raise AgentException("UNEXPECTED_MSG_TYPE", + f"Unexpected message type: {msg.type}") - # Запускаем задачу для прослушивания WS + self._connected = True self._listen_task = asyncio.create_task(self._listen()) + except asyncio.TimeoutError as e: + # Специфичная обработка таймаута + if self._ws and not self._ws.closed: + await self._ws.close() + if self._session and not self._session.closed: + await self._session.close() + + raise AgentException( + "TIMEOUT", "Agent did not send initial Status message within 5 seconds") from e + + except AgentException: + # Перехватываем наши собственные ошибки (INVALID_STATUS, UNEXPECTED_MSG_TYPE), + # закрываем ресурсы и пробрасываем ошибку дальше без изменений! + if self._ws and not self._ws.closed: + await self._ws.close() + if self._session and not self._session.closed: + await self._session.close() + raise + except Exception as e: - await self._session.close() - raise RuntimeError(f"Failed to connect to agent: {e}") from e + # Обработка всех остальных ошибок (например, aiohttp.ClientConnectionError) + if self._ws and not self._ws.closed: + await self._ws.close() + if self._session and not self._session.closed: + await self._session.close() - return self - - async def __aexit__(self, exc_type, exc_val, exc_tb): - """ - Выходит из контекстного менеджера, закрывает WebSocket соединение. - - Аргументы: - exc_type: Тип исключения, если оно произошло - exc_val: Значение исключения - exc_tb: Traceback исключения - """ - await self.close() + # Можно оставить RuntimeError, а можно тоже завернуть в AgentException + raise AgentException(code="CONNECTION_ERROR", + details=f"Failed to connect agent {self.id}: {e}") from e async def close(self): - """ - Закрывает WebSocket соединение и завершает сессию. - """ - self._connected = False + """Явное ручное закрытие соединения.""" if self._listen_task and not self._listen_task.done(): + # Это вызовет CancelledError внутри _listen и запустит finally self._listen_task.cancel() try: await self._listen_task except asyncio.CancelledError: pass + else: + await self._cleanup() + + async def _cleanup(self): + logger.info(f"Cleaning up agent {self.id}") + self._connected = False # Сбрасываем в любом случае + + if self._current_queue: + await self._current_queue.put(ConnectionError("Connection closed")) if self._ws and not self._ws.closed: await self._ws.close() - logger.info("WebSocket connection closed") - if self._session: + if self._session and not self._session.closed: await self._session.close() - logger.info("Client session closed") + # aiohttp рекомендует небольшую паузу для корректного закрытия TCP соединений под капотом + await asyncio.sleep(0.0) - async def send_message(self, text: str) -> 'ResponseIterator': + if self.on_disconnect: + try: + # Важно: callback синхронный, он не должен делать тяжелых IO операций + self.on_disconnect(self) + self.on_disconnect = None # Защита от двойного вызова + except Exception as e: + logger.error(f"Error in on_disconnect: {e}") + + async def send_message(self, text: str) -> AsyncIterator[SM.AgentEvent]: """ - Отправляет сообщение от пользователя на сервер и возвращает итератор для получения ответа. - - Аргументы: - text: Текст сообщения пользователя - - Возвращает: - ResponseIterator: Объект для асинхронной итерации по чанкам ответа - - Raises: - RuntimeError: Если соединение не установлено - aiohttp.ClientError: Если не удалось отправить сообщение + Нативный асинхронный генератор. + Не требует отдельного класса ResponseIterator. + Гарантированно освобождает блокировку. """ if not self._connected or not self._ws: - raise RuntimeError( - "Not connected to agent. Use 'async with' context manager.") + raise AgentException(code="NOT_CONNECTED", + details="Not connected. Call connect() first.") - message = CM.UserMessage( - type=CM.Type.USER_MESSAGE, - text=text - ) + if self._request_lock.locked(): + raise AgentBusyException( + "Agent is currently processing another request") + # Блокируем параллельные запросы + # если идет стриминг ответа, то при попытки отправить новое сообщение будет ошибка - ее рейзим(делаем AgentBusyError) + # храню состояние занят ли агент или нет self.is_busy - если занят, то кидаем исключение (состояние занятости агента хранится lock) + await self._request_lock.acquire() try: - await self._ws.send_str(message.model_dump_json()) - logger.debug(f"Sent user message: {text[:100]}...") - except Exception as e: - self._connected = False - raise aiohttp.ClientError(f"Failed to send message: {e}") from e + self._current_queue = asyncio.Queue() - return ResponseIterator(self._queue) + message = CM.UserMessage( + type=CM.Type.USER_MESSAGE, + text=text + ) + + await self._ws.send_str(message.model_dump_json()) + logger.debug(f"[{self.id}] Sent message: {text[:50]}...") + + # Читаем ответы из очереди + while True: + chunk = await self._current_queue.get() + + # Если прилетела ошибка (от агента или из _listen) + if isinstance(chunk, Exception): + raise chunk + + # Если конец ответа + if isinstance(chunk, SM.EventEnd): + # Если вам нужны tokens_used, можно сохранить их в атрибут self + break + + yield chunk + + finally: + # Если наш сервер (бэкенд агента) поддерживает команду отмены (например, специальный CM.Type.STOP_MESSAGE), то блок finally должен выглядеть так: + # if self._ws and not self._ws.closed: + # # Отправляем серверу команду "Хватит генерировать!" + # stop_msg = CM.UserMessage(type=CM.Type.STOP_GENERATION, text="") + # # (используйте правильную структуру вашего CM) + + # # Запускаем отправку как таску, чтобы не блокировать finally + # asyncio.create_task(self._ws.send_str(stop_msg.model_dump_json())) + if self._current_queue: + # 1. Сохраняем ссылку на локальную переменную и отвязываем от класса. + # Это гарантирует, что _listen больше не положит сюда новые сообщения. + orphan_queue = self._current_queue + self._current_queue = None + + # 2. Вычищаем (drain) очередь от сообщений, которые клиент не успел забрать + while not orphan_queue.empty(): + try: + orphan_msg = orphan_queue.get_nowait() + + # Если это исключение, просто логируем, в коллбек его кидать не стоит + if isinstance(orphan_msg, Exception): + logger.debug( + f"[{self.id}] Dropped exception from queue during cleanup: {orphan_msg}") + continue + + # 3. Отправляем "мусорные/осиротевшие" куски в callback + if self.callback: + self.callback(orphan_msg) + else: + logger.debug( + f"[{self.id}] Dropped orphaned message during cleanup") + + except asyncio.QueueEmpty: + break + + # Освобождаем лок, чтобы разрешить новые запросы + if self._request_lock.locked(): + self._request_lock.release() async def _listen(self): - """ - Прослушивает WebSocket в фоне и обрабатывает сообщения. + """" + Прослушивание вебсокета. """ try: async for msg in self._ws: @@ -178,73 +223,63 @@ class AgentApi: try: outgoing_msg = ServerMessage.model_validate_json( msg.data) - logger.debug( - f"Received message of type: {outgoing_msg.type}") if isinstance(outgoing_msg, SM.AgentEvent): - await self._queue.put(outgoing_msg) - elif isinstance(outgoing_msg, SM.Status): - if self.callback: + if self._current_queue: + await self._current_queue.put(outgoing_msg) + # Если очереди нет (клиент отменил запрос), но токены идут — шлем их в коллбек + elif self.callback: self.callback(outgoing_msg) - logger.info("Agent status update") + else: + logger.warning( + f"[{self.id}] AgentEvent without active request") + + elif isinstance(outgoing_msg, SM.EventEnd): + if self._current_queue: + await self._current_queue.put(outgoing_msg) + elif isinstance(outgoing_msg, SM.Error): if self.callback: self.callback(outgoing_msg) error = AgentException( outgoing_msg.code, outgoing_msg.details) - logger.error(f"Agent error: {error}") - # Не бросаем исключение, а вызываем callback + logger.error(f"[{self.id}] Agent error: {error}") + if self._current_queue: + await self._current_queue.put(error) + elif isinstance(outgoing_msg, SM.GracefulDisconnect): if self.callback: self.callback(outgoing_msg) - logger.info("Agent gracefully disconnecting") - self._connected = False - break + logger.info( + f"[{self.id}] Gracefully disconnecting") + break # Выход из цикла приведет к finally -> _cleanup + else: - # Другие типы через callback + # написать лог о неизвестном сообщении -> вывод неожиданного сообщения в лог + logger.warning( + f"[{self.id}] Unknown message type: {outgoing_msg.type}" + ) if self.callback: self.callback(outgoing_msg) except Exception as e: - logger.error(f"Failed to deserialize message: {e}") - if self.callback: - # Можно создать специальное сообщение об ошибке - pass + logger.error( + f"[{self.id}] Failed to deserialize message: {e}") + if self._current_queue: + await self._current_queue.put( + AgentException( + "PARSE_ERROR", f"Validation failed: {e}") + ) - elif msg.type == aiohttp.WSMsgType.ERROR: - error_msg = f"WebSocket error: {self._ws.exception()}" - logger.error(error_msg) - self._connected = False + elif msg.type in (aiohttp.WSMsgType.ERROR, aiohttp.WSMsgType.CLOSED): + logger.error( + f"[{self.id}] WebSocket closed/error: {msg.type}") break - elif msg.type == aiohttp.WSMsgType.CLOSED: - logger.info("WebSocket connection closed by server") - self._connected = False - break - - except Exception as e: - logger.error(f"Error in listen loop: {e}") - self._connected = False - - -class ResponseIterator: - """ - Асинхронный итератор для получения чанков ответа от агента. - """ - - def __init__(self, queue: asyncio.Queue): - self._queue = queue - self.tokens = 0 - - def __aiter__(self): - return self - - async def __anext__(self): - try: - chunk = await self._queue.get() - if isinstance(chunk, SM.EventEnd): - self.tokens = chunk.tokens_used - raise StopAsyncIteration - return chunk except asyncio.CancelledError: - raise StopAsyncIteration + pass # Нормальное прерывание при self.close() + except Exception as e: + logger.error(f"[{self.id}] Error in listen loop: {e}") + finally: + # Как только цикл прослушивания закончился, запускаем процедуру очистки + await self._cleanup()