agent_api/api/agent_api.py

285 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import logging
from typing import Callable, Optional, AsyncIterator
import aiohttp
import asyncio
from models import CM, SM, ClientMessage, ServerMessage
logger = logging.getLogger(__name__)
class AgentException(Exception):
def __init__(self, code: str, details: str):
self.code = code
self.details = details
super().__init__(f"Agent error ({code}): {details}")
class AgentBusyException(AgentException):
def __init__(self, details: str):
super().__init__("BUSY", details)
class AgentApi:
def __init__(
self,
agent_id: str,
url: str,
callback: Optional[Callable[[ServerMessage], None]] = None,
on_disconnect: Optional[Callable[['AgentApi'], None]] = None
):
self.id = agent_id # ID агента для словаря
self.url = url
self.callback = callback
self.on_disconnect = on_disconnect
self._session: aiohttp.ClientSession | None = None
self._ws: aiohttp.ClientWebSocketResponse | None = None
self._connected = False
self._current_queue: asyncio.Queue | None = None
self._request_lock = asyncio.Lock()
self._listen_task: asyncio.Task | None = None
async def connect(self):
"""Явное подключение к агенту."""
self._session = aiohttp.ClientSession()
try:
self._ws = await self._session.ws_connect(self.url, heartbeat=30)
# Используем asyncio.wait_for вместо timeout в receive
# здесь отлавливаем исключение
msg = await asyncio.wait_for(self._ws.receive(), timeout=5.0)
if msg.type == aiohttp.WSMsgType.TEXT:
status_msg = ServerMessage.model_validate_json(msg.data)
if isinstance(status_msg, SM.Status):
logger.info(f"Agent {self.id} is ready")
else:
raise AgentException(
"INVALID_STATUS", f"Expected SM.Status, got {status_msg.type}")
else:
raise AgentException("UNEXPECTED_MSG_TYPE",
f"Unexpected message type: {msg.type}")
self._connected = True
self._listen_task = asyncio.create_task(self._listen())
except asyncio.TimeoutError as e:
# Специфичная обработка таймаута
if self._ws and not self._ws.closed:
await self._ws.close()
if self._session and not self._session.closed:
await self._session.close()
raise AgentException(
"TIMEOUT", "Agent did not send initial Status message within 5 seconds") from e
except AgentException:
# Перехватываем наши собственные ошибки (INVALID_STATUS, UNEXPECTED_MSG_TYPE),
# закрываем ресурсы и пробрасываем ошибку дальше без изменений!
if self._ws and not self._ws.closed:
await self._ws.close()
if self._session and not self._session.closed:
await self._session.close()
raise
except Exception as e:
# Обработка всех остальных ошибок (например, aiohttp.ClientConnectionError)
if self._ws and not self._ws.closed:
await self._ws.close()
if self._session and not self._session.closed:
await self._session.close()
# Можно оставить RuntimeError, а можно тоже завернуть в AgentException
raise AgentException(code="CONNECTION_ERROR",
details=f"Failed to connect agent {self.id}: {e}") from e
async def close(self):
"""Явное ручное закрытие соединения."""
if self._listen_task and not self._listen_task.done():
# Это вызовет CancelledError внутри _listen и запустит finally
self._listen_task.cancel()
try:
await self._listen_task
except asyncio.CancelledError:
pass
else:
await self._cleanup()
async def _cleanup(self):
logger.info(f"Cleaning up agent {self.id}")
self._connected = False # Сбрасываем в любом случае
if self._current_queue:
await self._current_queue.put(ConnectionError("Connection closed"))
if self._ws and not self._ws.closed:
await self._ws.close()
if self._session and not self._session.closed:
await self._session.close()
# aiohttp рекомендует небольшую паузу для корректного закрытия TCP соединений под капотом
await asyncio.sleep(0.0)
if self.on_disconnect:
try:
# Важно: callback синхронный, он не должен делать тяжелых IO операций
self.on_disconnect(self)
self.on_disconnect = None # Защита от двойного вызова
except Exception as e:
logger.error(f"Error in on_disconnect: {e}")
async def send_message(self, text: str) -> AsyncIterator[SM.AgentEvent]:
"""
Нативный асинхронный генератор.
Не требует отдельного класса ResponseIterator.
Гарантированно освобождает блокировку.
"""
if not self._connected or not self._ws:
raise AgentException(code="NOT_CONNECTED",
details="Not connected. Call connect() first.")
if self._request_lock.locked():
raise AgentBusyException(
"Agent is currently processing another request")
# Блокируем параллельные запросы
# если идет стриминг ответа, то при попытки отправить новое сообщение будет ошибка - ее рейзим(делаем AgentBusyError)
# храню состояние занят ли агент или нет self.is_busy - если занят, то кидаем исключение (состояние занятости агента хранится lock)
await self._request_lock.acquire()
try:
self._current_queue = asyncio.Queue()
message = CM.UserMessage(
type=CM.Type.USER_MESSAGE,
text=text
)
await self._ws.send_str(message.model_dump_json())
logger.debug(f"[{self.id}] Sent message: {text[:50]}...")
# Читаем ответы из очереди
while True:
chunk = await self._current_queue.get()
# Если прилетела ошибка (от агента или из _listen)
if isinstance(chunk, Exception):
raise chunk
# Если конец ответа
if isinstance(chunk, SM.EventEnd):
# Если вам нужны tokens_used, можно сохранить их в атрибут self
break
yield chunk
finally:
# Если наш сервер (бэкенд агента) поддерживает команду отмены (например, специальный CM.Type.STOP_MESSAGE), то блок finally должен выглядеть так:
# if self._ws and not self._ws.closed:
# # Отправляем серверу команду "Хватит генерировать!"
# stop_msg = CM.UserMessage(type=CM.Type.STOP_GENERATION, text="")
# # (используйте правильную структуру вашего CM)
# # Запускаем отправку как таску, чтобы не блокировать finally
# asyncio.create_task(self._ws.send_str(stop_msg.model_dump_json()))
if self._current_queue:
# 1. Сохраняем ссылку на локальную переменную и отвязываем от класса.
# Это гарантирует, что _listen больше не положит сюда новые сообщения.
orphan_queue = self._current_queue
self._current_queue = None
# 2. Вычищаем (drain) очередь от сообщений, которые клиент не успел забрать
while not orphan_queue.empty():
try:
orphan_msg = orphan_queue.get_nowait()
# Если это исключение, просто логируем, в коллбек его кидать не стоит
if isinstance(orphan_msg, Exception):
logger.debug(
f"[{self.id}] Dropped exception from queue during cleanup: {orphan_msg}")
continue
# 3. Отправляем "мусорные/осиротевшие" куски в callback
if self.callback:
self.callback(orphan_msg)
else:
logger.debug(
f"[{self.id}] Dropped orphaned message during cleanup")
except asyncio.QueueEmpty:
break
# Освобождаем лок, чтобы разрешить новые запросы
if self._request_lock.locked():
self._request_lock.release()
async def _listen(self):
""""
Прослушивание вебсокета.
"""
try:
async for msg in self._ws:
if msg.type == aiohttp.WSMsgType.TEXT:
try:
outgoing_msg = ServerMessage.model_validate_json(
msg.data)
if isinstance(outgoing_msg, SM.AgentEvent):
if self._current_queue:
await self._current_queue.put(outgoing_msg)
# Если очереди нет (клиент отменил запрос), но токены идут — шлем их в коллбек
elif self.callback:
self.callback(outgoing_msg)
else:
logger.warning(
f"[{self.id}] AgentEvent without active request")
elif isinstance(outgoing_msg, SM.EventEnd):
if self._current_queue:
await self._current_queue.put(outgoing_msg)
elif isinstance(outgoing_msg, SM.Error):
if self.callback:
self.callback(outgoing_msg)
error = AgentException(
outgoing_msg.code, outgoing_msg.details)
logger.error(f"[{self.id}] Agent error: {error}")
if self._current_queue:
await self._current_queue.put(error)
elif isinstance(outgoing_msg, SM.GracefulDisconnect):
if self.callback:
self.callback(outgoing_msg)
logger.info(
f"[{self.id}] Gracefully disconnecting")
break # Выход из цикла приведет к finally -> _cleanup
else:
# написать лог о неизвестном сообщении -> вывод неожиданного сообщения в лог
logger.warning(
f"[{self.id}] Unknown message type: {outgoing_msg.type}"
)
if self.callback:
self.callback(outgoing_msg)
except Exception as e:
logger.error(
f"[{self.id}] Failed to deserialize message: {e}")
if self._current_queue:
await self._current_queue.put(
AgentException(
"PARSE_ERROR", f"Validation failed: {e}")
)
elif msg.type in (aiohttp.WSMsgType.ERROR, aiohttp.WSMsgType.CLOSED):
logger.error(
f"[{self.id}] WebSocket closed/error: {msg.type}")
break
except asyncio.CancelledError:
pass # Нормальное прерывание при self.close()
except Exception as e:
logger.error(f"[{self.id}] Error in listen loop: {e}")
finally:
# Как только цикл прослушивания закончился, запускаем процедуру очистки
await self._cleanup()