import logging from typing import Callable, Optional, AsyncIterator import aiohttp import asyncio from lambda_agent_api.server import * from lambda_agent_api.client import * logger = logging.getLogger(__name__) class AgentException(Exception): def __init__(self, code: str, details: str): self.code = code self.details = details super().__init__(f"Agent error ({code}): {details}") class AgentBusyException(AgentException): def __init__(self, details: str): super().__init__("BUSY", details) class AgentApi: def __init__( self, agent_id: str, url: str, callback: Optional[Callable[[ServerMessage], None]] = None, on_disconnect: Optional[Callable[['AgentApi'], None]] = None ): self.id = agent_id # ID агента для словаря self.url = url self.callback = callback self.on_disconnect = on_disconnect self._session: aiohttp.ClientSession | None = None self._ws: aiohttp.ClientWebSocketResponse | None = None self._connected = False self._current_queue: asyncio.Queue | None = None self._request_lock = asyncio.Lock() self._listen_task: asyncio.Task | None = None async def connect(self): """Явное подключение к агенту.""" self._session = aiohttp.ClientSession() try: self._ws = await self._session.ws_connect(self.url, heartbeat=30) # Используем asyncio.wait_for вместо timeout в receive # здесь отлавливаем исключение msg = await asyncio.wait_for(self._ws.receive(), timeout=5.0) if msg.type == aiohttp.WSMsgType.TEXT: status_msg = ServerMessage.validate_json(msg.data) if isinstance(status_msg, MsgStatus): logger.info(f"Agent {self.id} is ready") else: raise AgentException( "INVALID_STATUS", f"Expected SM.Status, got {status_msg.type}") else: raise AgentException("UNEXPECTED_MSG_TYPE", f"Unexpected message type: {msg.type}") self._connected = True self._listen_task = asyncio.create_task(self._listen()) except asyncio.TimeoutError as e: # Специфичная обработка таймаута if self._ws and not self._ws.closed: await self._ws.close() if self._session and not self._session.closed: await self._session.close() raise AgentException( "TIMEOUT", "Agent did not send initial Status message within 5 seconds") from e except AgentException: # Перехватываем наши собственные ошибки (INVALID_STATUS, UNEXPECTED_MSG_TYPE), # закрываем ресурсы и пробрасываем ошибку дальше без изменений! if self._ws and not self._ws.closed: await self._ws.close() if self._session and not self._session.closed: await self._session.close() raise except Exception as e: # Обработка всех остальных ошибок (например, aiohttp.ClientConnectionError) if self._ws and not self._ws.closed: await self._ws.close() if self._session and not self._session.closed: await self._session.close() # Можно оставить RuntimeError, а можно тоже завернуть в AgentException raise AgentException(code="CONNECTION_ERROR", details=f"Failed to connect agent {self.id}: {e}") from e async def close(self): """Явное ручное закрытие соединения.""" if self._listen_task and not self._listen_task.done(): # Это вызовет CancelledError внутри _listen и запустит finally self._listen_task.cancel() try: await self._listen_task except asyncio.CancelledError: pass else: await self._cleanup() async def _cleanup(self): logger.info(f"Cleaning up agent {self.id}") self._connected = False # Сбрасываем в любом случае if self._current_queue: await self._current_queue.put(ConnectionError("Connection closed")) if self._ws and not self._ws.closed: await self._ws.close() if self._session and not self._session.closed: await self._session.close() # aiohttp рекомендует небольшую паузу для корректного закрытия TCP соединений под капотом await asyncio.sleep(0.0) if self.on_disconnect: try: # Важно: callback синхронный, он не должен делать тяжелых IO операций self.on_disconnect(self) self.on_disconnect = None # Защита от двойного вызова except Exception as e: logger.error(f"Error in on_disconnect: {e}") async def send_message(self, text: str) -> AsyncIterator[AgentEventUnion]: """ Нативный асинхронный генератор. Не требует отдельного класса ResponseIterator. Гарантированно освобождает блокировку. """ if not self._connected or not self._ws: raise AgentException(code="NOT_CONNECTED", details="Not connected. Call connect() first.") if self._request_lock.locked(): raise AgentBusyException( "Agent is currently processing another request") # Блокируем параллельные запросы # если идет стриминг ответа, то при попытки отправить новое сообщение будет ошибка - ее рейзим(делаем AgentBusyError) # храню состояние занят ли агент или нет self.is_busy - если занят, то кидаем исключение (состояние занятости агента хранится lock) await self._request_lock.acquire() try: self._current_queue = asyncio.Queue() message = MsgUserMessage( type=EClientMessage.USER_MESSAGE, text=text ) await self._ws.send_str(message.model_dump_json()) logger.debug(f"[{self.id}] Sent message: {text[:50]}...") # Читаем ответы из очереди while True: chunk = await self._current_queue.get() # Если прилетела ошибка (от агента или из _listen) if isinstance(chunk, Exception): raise chunk # Если конец ответа if isinstance(chunk, MsgEventEnd): # Если вам нужны tokens_used, можно сохранить их в атрибут self break yield chunk finally: # Если наш сервер (бэкенд агента) поддерживает команду отмены (например, специальный CM.Type.STOP_MESSAGE), то блок finally должен выглядеть так: # if self._ws and not self._ws.closed: # # Отправляем серверу команду "Хватит генерировать!" # stop_msg = CM.UserMessage(type=CM.Type.STOP_GENERATION, text="") # # (используйте правильную структуру вашего CM) # # Запускаем отправку как таску, чтобы не блокировать finally # asyncio.create_task(self._ws.send_str(stop_msg.model_dump_json())) if self._current_queue: # 1. Сохраняем ссылку на локальную переменную и отвязываем от класса. # Это гарантирует, что _listen больше не положит сюда новые сообщения. orphan_queue = self._current_queue self._current_queue = None # 2. Вычищаем (drain) очередь от сообщений, которые клиент не успел забрать while not orphan_queue.empty(): try: orphan_msg = orphan_queue.get_nowait() # Если это исключение, просто логируем, в коллбек его кидать не стоит if isinstance(orphan_msg, Exception): logger.debug( f"[{self.id}] Dropped exception from queue during cleanup: {orphan_msg}") continue # 3. Отправляем "мусорные/осиротевшие" куски в callback if self.callback: self.callback(orphan_msg) else: logger.debug( f"[{self.id}] Dropped orphaned message during cleanup") except asyncio.QueueEmpty: break # Освобождаем лок, чтобы разрешить новые запросы if self._request_lock.locked(): self._request_lock.release() async def _listen(self): """" Прослушивание вебсокета. """ try: async for msg in self._ws: if msg.type == aiohttp.WSMsgType.TEXT: try: outgoing_msg = ServerMessage.validate_json( msg.data) if isinstance(outgoing_msg, MsgEventTextChunk): if self._current_queue: await self._current_queue.put(outgoing_msg) # Если очереди нет (клиент отменил запрос), но токены идут — шлем их в коллбек elif self.callback: self.callback(outgoing_msg) else: logger.warning( f"[{self.id}] AgentEvent without active request") elif isinstance(outgoing_msg, MsgEventEnd): if self._current_queue: await self._current_queue.put(outgoing_msg) elif isinstance(outgoing_msg, MsgError): if self.callback: self.callback(outgoing_msg) error = AgentException( outgoing_msg.code, outgoing_msg.details) logger.error(f"[{self.id}] Agent error: {error}") if self._current_queue: await self._current_queue.put(error) elif isinstance(outgoing_msg, MsgGracefulDisconnect): if self.callback: self.callback(outgoing_msg) logger.info( f"[{self.id}] Gracefully disconnecting") break # Выход из цикла приведет к finally -> _cleanup else: # написать лог о неизвестном сообщении -> вывод неожиданного сообщения в лог logger.warning( f"[{self.id}] Unknown message type: {outgoing_msg.type}" ) if self.callback: self.callback(outgoing_msg) except Exception as e: logger.error( f"[{self.id}] Failed to deserialize message: {e}") if self._current_queue: await self._current_queue.put( AgentException( "PARSE_ERROR", f"Validation failed: {e}") ) elif msg.type in (aiohttp.WSMsgType.ERROR, aiohttp.WSMsgType.CLOSED): logger.error( f"[{self.id}] WebSocket closed/error: {msg.type}") break except asyncio.CancelledError: pass # Нормальное прерывание при self.close() except Exception as e: logger.error(f"[{self.id}] Error in listen loop: {e}") finally: # Как только цикл прослушивания закончился, запускаем процедуру очистки await self._cleanup()