diff --git a/api/agent_api.py b/api/agent_api.py index 0fb0c2c..5ee2359 100644 --- a/api/agent_api.py +++ b/api/agent_api.py @@ -1,6 +1,7 @@ import logging -from typing import AsyncGenerator +from typing import Callable, Optional import aiohttp +import asyncio from models import IM, OM, ClientMessage, ServerMessage @@ -30,29 +31,39 @@ class AgentApi: (IM для входящих, OM для исходящих сообщений). Пример использования: - async with AgentApi("ws://localhost:8000") as agent: - await agent.send_user_message("Hello, agent!") - async for message in agent.listen(): - print(message) + async with AgentApi("ws://localhost:8000", callback=my_callback) as agent: + response = await agent.send_message("Hello, agent!") + async for chunk in response: + match chunk: + case OM.EventTextChunk(): + print(chunk.text, end="") + print(f" [{response.tokens} токенов]") Атрибуты: url: URL WebSocket сервера агента + callback: Функция обратного вызова для событий не в процессе генерации _session: aiohttp ClientSession для WebSocket соединения _ws: Активное WebSocket соединение _connected: Флаг состояния соединения + _queue: Очередь для передачи чанков ответа + _listen_task: Задача для прослушивания WS """ - def __init__(self, url: str): + def __init__(self, url: str, callback: Optional[Callable[[ServerMessage], None]] = None): """ Инициализирует клиент агента. Аргументы: url: URL WebSocket сервера (например, "ws://localhost:8000") + callback: Функция обратного вызова для событий не в процессе генерации ответа """ self.url = url + self.callback = callback self._session: aiohttp.ClientSession | None = None self._ws: aiohttp.ClientWebSocketResponse | None = None self._connected = False + self._queue = asyncio.Queue() + self._listen_task: asyncio.Task | None = None async def __aenter__(self): """ @@ -69,6 +80,25 @@ class AgentApi: self._ws = await self._session.ws_connect(self.url, heartbeat=30) self._connected = True logger.info(f"Connected to agent at {self.url}") + + # Ожидаем OM.Status при открытии соединения + msg = await self._ws.receive() + if msg.type == aiohttp.WSMsgType.TEXT: + status_msg = ServerMessage.model_validate_json(msg.data) + if isinstance(status_msg, OM.Status): + if self.callback: + self.callback(status_msg) + logger.info("Agent is ready to accept messages") + else: + raise RuntimeError( + f"Expected OM.Status on connection, got {status_msg.type}") + else: + raise RuntimeError( + f"Unexpected message type on connection: {msg.type}") + + # Запускаем задачу для прослушивания WS + self._listen_task = asyncio.create_task(self._listen()) + except Exception as e: await self._session.close() raise RuntimeError(f"Failed to connect to agent: {e}") from e @@ -91,6 +121,13 @@ class AgentApi: Закрывает WebSocket соединение и завершает сессию. """ self._connected = False + if self._listen_task and not self._listen_task.done(): + self._listen_task.cancel() + try: + await self._listen_task + except asyncio.CancelledError: + pass + if self._ws and not self._ws.closed: await self._ws.close() logger.info("WebSocket connection closed") @@ -99,13 +136,16 @@ class AgentApi: await self._session.close() logger.info("Client session closed") - async def send_user_message(self, text: str) -> None: + async def send_message(self, text: str) -> 'ResponseIterator': """ - Отправляет сообщение от пользователя на сервер. + Отправляет сообщение от пользователя на сервер и возвращает итератор для получения ответа. Аргументы: text: Текст сообщения пользователя + Возвращает: + ResponseIterator: Объект для асинхронной итерации по чанкам ответа + Raises: RuntimeError: Если соединение не установлено aiohttp.ClientError: Если не удалось отправить сообщение @@ -126,34 +166,12 @@ class AgentApi: self._connected = False raise aiohttp.ClientError(f"Failed to send message: {e}") from e - async def listen(self) -> AsyncGenerator[ServerMessage, None]: + return ResponseIterator(self._queue) + + async def _listen(self): """ - Читает поток сообщений от сервера и десериализует их. - - Это асинхронный генератор, который обрабатывает: - - OM.Status: Логирует о готовности - - OM.Error: Бросает AgentException - - OM.GracefulDisconnect: Корректно завершает соединение - - OM.AgentEvent: Возвращает объект события - - Yields: - ServerMessage: Десериализованное сообщение от сервера - - Raises: - RuntimeError: Если соединение не установлено - AgentException: Если сервер отправил ошибку - - Пример: - async for message in agent.listen(): - if isinstance(message, OM.Status): - print("Agent is ready") - elif isinstance(message, OM.AgentEvent): - print(f"Agent event: {message.subtype}") + Прослушивает WebSocket в фоне и обрабатывает сообщения. """ - if not self._connected or not self._ws: - raise RuntimeError( - "Not connected to agent. Use 'async with' context manager.") - try: async for msg in self._ws: if msg.type == aiohttp.WSMsgType.TEXT: @@ -163,57 +181,70 @@ class AgentApi: logger.debug( f"Received message of type: {outgoing_msg.type}") - # Обработка специальных событий - if isinstance(outgoing_msg, OM.Status): - logger.info("Agent is ready to accept messages") - yield outgoing_msg - + if isinstance(outgoing_msg, OM.AgentEvent): + await self._queue.put(outgoing_msg) + elif isinstance(outgoing_msg, OM.Status): + if self.callback: + self.callback(outgoing_msg) + logger.info("Agent status update") elif isinstance(outgoing_msg, OM.Error): + if self.callback: + self.callback(outgoing_msg) error = AgentException( outgoing_msg.code, outgoing_msg.details) logger.error(f"Agent error: {error}") - raise error - + # Не бросаем исключение, а вызываем callback elif isinstance(outgoing_msg, OM.GracefulDisconnect): + if self.callback: + self.callback(outgoing_msg) logger.info("Agent gracefully disconnecting") self._connected = False - yield outgoing_msg break - else: - # OM.AgentEvent и другие типы - yield outgoing_msg + # Другие типы через callback + if self.callback: + self.callback(outgoing_msg) - except AgentException: - # Пробрасываем исключения агента дальше - raise except Exception as e: logger.error(f"Failed to deserialize message: {e}") - raise RuntimeError( - f"Message deserialization error: {e}") from e + if self.callback: + # Можно создать специальное сообщение об ошибке + pass elif msg.type == aiohttp.WSMsgType.ERROR: error_msg = f"WebSocket error: {self._ws.exception()}" logger.error(error_msg) self._connected = False - raise RuntimeError(error_msg) + break elif msg.type == aiohttp.WSMsgType.CLOSED: logger.info("WebSocket connection closed by server") self._connected = False break - except AgentException: - raise - except RuntimeError: - raise except Exception as e: logger.error(f"Error in listen loop: {e}") self._connected = False - raise RuntimeError(f"Listen loop error: {e}") from e - finally: - # Do not reset _connected here - it should only be reset on: - # - explicit close() call - # - OM.GracefulDisconnect handling - # - actual connection errors - pass + + +class ResponseIterator: + """ + Асинхронный итератор для получения чанков ответа от агента. + """ + + def __init__(self, queue: asyncio.Queue): + self._queue = queue + self.tokens = 0 + + def __aiter__(self): + return self + + async def __anext__(self): + try: + chunk = await self._queue.get() + if isinstance(chunk, OM.EventEnd): + self.tokens = chunk.tokens_used + raise StopAsyncIteration + return chunk + except asyncio.CancelledError: + raise StopAsyncIteration