адаптация для установки пакетом

This commit is contained in:
Егор Кандрушин 2026-04-01 23:30:00 +03:00
parent a5ef5abac7
commit b34cbaf677
9 changed files with 260 additions and 258 deletions

View file

@ -1,285 +0,0 @@
import logging
from typing import Callable, Optional, AsyncIterator
import aiohttp
import asyncio
from models import CM, SM, ClientMessage, ServerMessage
logger = logging.getLogger(__name__)
class AgentException(Exception):
def __init__(self, code: str, details: str):
self.code = code
self.details = details
super().__init__(f"Agent error ({code}): {details}")
class AgentBusyException(AgentException):
def __init__(self, details: str):
super().__init__("BUSY", details)
class AgentApi:
def __init__(
self,
agent_id: str,
url: str,
callback: Optional[Callable[[ServerMessage], None]] = None,
on_disconnect: Optional[Callable[['AgentApi'], None]] = None
):
self.id = agent_id # ID агента для словаря
self.url = url
self.callback = callback
self.on_disconnect = on_disconnect
self._session: aiohttp.ClientSession | None = None
self._ws: aiohttp.ClientWebSocketResponse | None = None
self._connected = False
self._current_queue: asyncio.Queue | None = None
self._request_lock = asyncio.Lock()
self._listen_task: asyncio.Task | None = None
async def connect(self):
"""Явное подключение к агенту."""
self._session = aiohttp.ClientSession()
try:
self._ws = await self._session.ws_connect(self.url, heartbeat=30)
# Используем asyncio.wait_for вместо timeout в receive
# здесь отлавливаем исключение
msg = await asyncio.wait_for(self._ws.receive(), timeout=5.0)
if msg.type == aiohttp.WSMsgType.TEXT:
status_msg = ServerMessage.model_validate_json(msg.data)
if isinstance(status_msg, SM.Status):
logger.info(f"Agent {self.id} is ready")
else:
raise AgentException(
"INVALID_STATUS", f"Expected SM.Status, got {status_msg.type}")
else:
raise AgentException("UNEXPECTED_MSG_TYPE",
f"Unexpected message type: {msg.type}")
self._connected = True
self._listen_task = asyncio.create_task(self._listen())
except asyncio.TimeoutError as e:
# Специфичная обработка таймаута
if self._ws and not self._ws.closed:
await self._ws.close()
if self._session and not self._session.closed:
await self._session.close()
raise AgentException(
"TIMEOUT", "Agent did not send initial Status message within 5 seconds") from e
except AgentException:
# Перехватываем наши собственные ошибки (INVALID_STATUS, UNEXPECTED_MSG_TYPE),
# закрываем ресурсы и пробрасываем ошибку дальше без изменений!
if self._ws and not self._ws.closed:
await self._ws.close()
if self._session and not self._session.closed:
await self._session.close()
raise
except Exception as e:
# Обработка всех остальных ошибок (например, aiohttp.ClientConnectionError)
if self._ws and not self._ws.closed:
await self._ws.close()
if self._session and not self._session.closed:
await self._session.close()
# Можно оставить RuntimeError, а можно тоже завернуть в AgentException
raise AgentException(code="CONNECTION_ERROR",
details=f"Failed to connect agent {self.id}: {e}") from e
async def close(self):
"""Явное ручное закрытие соединения."""
if self._listen_task and not self._listen_task.done():
# Это вызовет CancelledError внутри _listen и запустит finally
self._listen_task.cancel()
try:
await self._listen_task
except asyncio.CancelledError:
pass
else:
await self._cleanup()
async def _cleanup(self):
logger.info(f"Cleaning up agent {self.id}")
self._connected = False # Сбрасываем в любом случае
if self._current_queue:
await self._current_queue.put(ConnectionError("Connection closed"))
if self._ws and not self._ws.closed:
await self._ws.close()
if self._session and not self._session.closed:
await self._session.close()
# aiohttp рекомендует небольшую паузу для корректного закрытия TCP соединений под капотом
await asyncio.sleep(0.0)
if self.on_disconnect:
try:
# Важно: callback синхронный, он не должен делать тяжелых IO операций
self.on_disconnect(self)
self.on_disconnect = None # Защита от двойного вызова
except Exception as e:
logger.error(f"Error in on_disconnect: {e}")
async def send_message(self, text: str) -> AsyncIterator[SM.AgentEvent]:
"""
Нативный асинхронный генератор.
Не требует отдельного класса ResponseIterator.
Гарантированно освобождает блокировку.
"""
if not self._connected or not self._ws:
raise AgentException(code="NOT_CONNECTED",
details="Not connected. Call connect() first.")
if self._request_lock.locked():
raise AgentBusyException(
"Agent is currently processing another request")
# Блокируем параллельные запросы
# если идет стриминг ответа, то при попытки отправить новое сообщение будет ошибка - ее рейзим(делаем AgentBusyError)
# храню состояние занят ли агент или нет self.is_busy - если занят, то кидаем исключение (состояние занятости агента хранится lock)
await self._request_lock.acquire()
try:
self._current_queue = asyncio.Queue()
message = CM.UserMessage(
type=CM.Type.USER_MESSAGE,
text=text
)
await self._ws.send_str(message.model_dump_json())
logger.debug(f"[{self.id}] Sent message: {text[:50]}...")
# Читаем ответы из очереди
while True:
chunk = await self._current_queue.get()
# Если прилетела ошибка (от агента или из _listen)
if isinstance(chunk, Exception):
raise chunk
# Если конец ответа
if isinstance(chunk, SM.EventEnd):
# Если вам нужны tokens_used, можно сохранить их в атрибут self
break
yield chunk
finally:
# Если наш сервер (бэкенд агента) поддерживает команду отмены (например, специальный CM.Type.STOP_MESSAGE), то блок finally должен выглядеть так:
# if self._ws and not self._ws.closed:
# # Отправляем серверу команду "Хватит генерировать!"
# stop_msg = CM.UserMessage(type=CM.Type.STOP_GENERATION, text="")
# # (используйте правильную структуру вашего CM)
# # Запускаем отправку как таску, чтобы не блокировать finally
# asyncio.create_task(self._ws.send_str(stop_msg.model_dump_json()))
if self._current_queue:
# 1. Сохраняем ссылку на локальную переменную и отвязываем от класса.
# Это гарантирует, что _listen больше не положит сюда новые сообщения.
orphan_queue = self._current_queue
self._current_queue = None
# 2. Вычищаем (drain) очередь от сообщений, которые клиент не успел забрать
while not orphan_queue.empty():
try:
orphan_msg = orphan_queue.get_nowait()
# Если это исключение, просто логируем, в коллбек его кидать не стоит
if isinstance(orphan_msg, Exception):
logger.debug(
f"[{self.id}] Dropped exception from queue during cleanup: {orphan_msg}")
continue
# 3. Отправляем "мусорные/осиротевшие" куски в callback
if self.callback:
self.callback(orphan_msg)
else:
logger.debug(
f"[{self.id}] Dropped orphaned message during cleanup")
except asyncio.QueueEmpty:
break
# Освобождаем лок, чтобы разрешить новые запросы
if self._request_lock.locked():
self._request_lock.release()
async def _listen(self):
""""
Прослушивание вебсокета.
"""
try:
async for msg in self._ws:
if msg.type == aiohttp.WSMsgType.TEXT:
try:
outgoing_msg = ServerMessage.model_validate_json(
msg.data)
if isinstance(outgoing_msg, SM.AgentEvent):
if self._current_queue:
await self._current_queue.put(outgoing_msg)
# Если очереди нет (клиент отменил запрос), но токены идут — шлем их в коллбек
elif self.callback:
self.callback(outgoing_msg)
else:
logger.warning(
f"[{self.id}] AgentEvent without active request")
elif isinstance(outgoing_msg, SM.EventEnd):
if self._current_queue:
await self._current_queue.put(outgoing_msg)
elif isinstance(outgoing_msg, SM.Error):
if self.callback:
self.callback(outgoing_msg)
error = AgentException(
outgoing_msg.code, outgoing_msg.details)
logger.error(f"[{self.id}] Agent error: {error}")
if self._current_queue:
await self._current_queue.put(error)
elif isinstance(outgoing_msg, SM.GracefulDisconnect):
if self.callback:
self.callback(outgoing_msg)
logger.info(
f"[{self.id}] Gracefully disconnecting")
break # Выход из цикла приведет к finally -> _cleanup
else:
# написать лог о неизвестном сообщении -> вывод неожиданного сообщения в лог
logger.warning(
f"[{self.id}] Unknown message type: {outgoing_msg.type}"
)
if self.callback:
self.callback(outgoing_msg)
except Exception as e:
logger.error(
f"[{self.id}] Failed to deserialize message: {e}")
if self._current_queue:
await self._current_queue.put(
AgentException(
"PARSE_ERROR", f"Validation failed: {e}")
)
elif msg.type in (aiohttp.WSMsgType.ERROR, aiohttp.WSMsgType.CLOSED):
logger.error(
f"[{self.id}] WebSocket closed/error: {msg.type}")
break
except asyncio.CancelledError:
pass # Нормальное прерывание при self.close()
except Exception as e:
logger.error(f"[{self.id}] Error in listen loop: {e}")
finally:
# Как только цикл прослушивания закончился, запускаем процедуру очистки
await self._cleanup()