import logging from typing import Callable, Optional import aiohttp import asyncio from models import IM, OM, ClientMessage, ServerMessage logger = logging.getLogger(__name__) class AgentException(Exception): """ Кастомное исключение для ошибок, полученных от агента. Атрибуты: code: Код ошибки из сообщения OM.Error details: Детали ошибки """ def __init__(self, code: str, details: str): self.code = code self.details = details super().__init__(f"Agent error ({code}): {details}") class AgentApi: """ Асинхронный клиент для взаимодействия с AI-агентом через WebSocket. Класс инкапсулирует обмен сообщениями согласно контракту Pydantic (IM для входящих, OM для исходящих сообщений). Пример использования: async with AgentApi("ws://localhost:8000", callback=my_callback) as agent: response = await agent.send_message("Hello, agent!") async for chunk in response: match chunk: case OM.EventTextChunk(): print(chunk.text, end="") print(f" [{response.tokens} токенов]") Атрибуты: url: URL WebSocket сервера агента callback: Функция обратного вызова для событий не в процессе генерации _session: aiohttp ClientSession для WebSocket соединения _ws: Активное WebSocket соединение _connected: Флаг состояния соединения _queue: Очередь для передачи чанков ответа _listen_task: Задача для прослушивания WS """ def __init__(self, url: str, callback: Optional[Callable[[ServerMessage], None]] = None): """ Инициализирует клиент агента. Аргументы: url: URL WebSocket сервера (например, "ws://localhost:8000") callback: Функция обратного вызова для событий не в процессе генерации ответа """ self.url = url self.callback = callback self._session: aiohttp.ClientSession | None = None self._ws: aiohttp.ClientWebSocketResponse | None = None self._connected = False self._queue = asyncio.Queue() self._listen_task: asyncio.Task | None = None async def __aenter__(self): """ Входит в контекстный менеджер, устанавливает WebSocket соединение. Возвращает: self: Экземпляр AgentApi Raises: RuntimeError: Если не удалось установить соединение """ self._session = aiohttp.ClientSession() try: self._ws = await self._session.ws_connect(self.url, heartbeat=30) self._connected = True logger.info(f"Connected to agent at {self.url}") # Ожидаем OM.Status при открытии соединения msg = await self._ws.receive() if msg.type == aiohttp.WSMsgType.TEXT: status_msg = ServerMessage.model_validate_json(msg.data) if isinstance(status_msg, OM.Status): if self.callback: self.callback(status_msg) logger.info("Agent is ready to accept messages") else: raise RuntimeError( f"Expected OM.Status on connection, got {status_msg.type}") else: raise RuntimeError( f"Unexpected message type on connection: {msg.type}") # Запускаем задачу для прослушивания WS self._listen_task = asyncio.create_task(self._listen()) except Exception as e: await self._session.close() raise RuntimeError(f"Failed to connect to agent: {e}") from e return self async def __aexit__(self, exc_type, exc_val, exc_tb): """ Выходит из контекстного менеджера, закрывает WebSocket соединение. Аргументы: exc_type: Тип исключения, если оно произошло exc_val: Значение исключения exc_tb: Traceback исключения """ await self.close() async def close(self): """ Закрывает WebSocket соединение и завершает сессию. """ self._connected = False if self._listen_task and not self._listen_task.done(): self._listen_task.cancel() try: await self._listen_task except asyncio.CancelledError: pass if self._ws and not self._ws.closed: await self._ws.close() logger.info("WebSocket connection closed") if self._session: await self._session.close() logger.info("Client session closed") async def send_message(self, text: str) -> 'ResponseIterator': """ Отправляет сообщение от пользователя на сервер и возвращает итератор для получения ответа. Аргументы: text: Текст сообщения пользователя Возвращает: ResponseIterator: Объект для асинхронной итерации по чанкам ответа Raises: RuntimeError: Если соединение не установлено aiohttp.ClientError: Если не удалось отправить сообщение """ if not self._connected or not self._ws: raise RuntimeError( "Not connected to agent. Use 'async with' context manager.") message = IM.UserMessage( type=IM.Type.USER_MESSAGE, text=text ) try: await self._ws.send_str(message.model_dump_json()) logger.debug(f"Sent user message: {text[:100]}...") except Exception as e: self._connected = False raise aiohttp.ClientError(f"Failed to send message: {e}") from e return ResponseIterator(self._queue) async def _listen(self): """ Прослушивает WebSocket в фоне и обрабатывает сообщения. """ try: async for msg in self._ws: if msg.type == aiohttp.WSMsgType.TEXT: try: outgoing_msg = ServerMessage.model_validate_json( msg.data) logger.debug( f"Received message of type: {outgoing_msg.type}") if isinstance(outgoing_msg, OM.AgentEvent): await self._queue.put(outgoing_msg) elif isinstance(outgoing_msg, OM.Status): if self.callback: self.callback(outgoing_msg) logger.info("Agent status update") elif isinstance(outgoing_msg, OM.Error): if self.callback: self.callback(outgoing_msg) error = AgentException( outgoing_msg.code, outgoing_msg.details) logger.error(f"Agent error: {error}") # Не бросаем исключение, а вызываем callback elif isinstance(outgoing_msg, OM.GracefulDisconnect): if self.callback: self.callback(outgoing_msg) logger.info("Agent gracefully disconnecting") self._connected = False break else: # Другие типы через callback if self.callback: self.callback(outgoing_msg) except Exception as e: logger.error(f"Failed to deserialize message: {e}") if self.callback: # Можно создать специальное сообщение об ошибке pass elif msg.type == aiohttp.WSMsgType.ERROR: error_msg = f"WebSocket error: {self._ws.exception()}" logger.error(error_msg) self._connected = False break elif msg.type == aiohttp.WSMsgType.CLOSED: logger.info("WebSocket connection closed by server") self._connected = False break except Exception as e: logger.error(f"Error in listen loop: {e}") self._connected = False class ResponseIterator: """ Асинхронный итератор для получения чанков ответа от агента. """ def __init__(self, queue: asyncio.Queue): self._queue = queue self.tokens = 0 def __aiter__(self): return self async def __anext__(self): try: chunk = await self._queue.get() if isinstance(chunk, OM.EventEnd): self.tokens = chunk.tokens_used raise StopAsyncIteration return chunk except asyncio.CancelledError: raise StopAsyncIteration