#5 Реализовать клиентскую часть канала общения. Исправлены все недочеты
This commit is contained in:
parent
0c19934008
commit
02467a4e12
1 changed files with 198 additions and 163 deletions
361
api/agent_api.py
361
api/agent_api.py
|
|
@ -1,176 +1,221 @@
|
||||||
import logging
|
import logging
|
||||||
from typing import Callable, Optional
|
from typing import Callable, Optional, AsyncIterator
|
||||||
import aiohttp
|
import aiohttp
|
||||||
import asyncio
|
import asyncio
|
||||||
from models import CM, SM, ClientMessage, ServerMessage
|
from models import CM, SM, ClientMessage, ServerMessage
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class AgentException(Exception):
|
class AgentException(Exception):
|
||||||
"""
|
|
||||||
Кастомное исключение для ошибок, полученных от агента.
|
|
||||||
|
|
||||||
Атрибуты:
|
|
||||||
code: Код ошибки из сообщения SM.Error
|
|
||||||
details: Детали ошибки
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, code: str, details: str):
|
def __init__(self, code: str, details: str):
|
||||||
self.code = code
|
self.code = code
|
||||||
self.details = details
|
self.details = details
|
||||||
super().__init__(f"Agent error ({code}): {details}")
|
super().__init__(f"Agent error ({code}): {details}")
|
||||||
|
|
||||||
|
|
||||||
|
class AgentBusyException(AgentException):
|
||||||
|
def __init__(self, details: str):
|
||||||
|
super().__init__("BUSY", details)
|
||||||
|
|
||||||
|
|
||||||
class AgentApi:
|
class AgentApi:
|
||||||
"""
|
def __init__(
|
||||||
Асинхронный клиент для взаимодействия с AI-агентом через WebSocket.
|
self,
|
||||||
|
agent_id: str,
|
||||||
Класс инкапсулирует обмен сообщениями согласно контракту Pydantic
|
url: str,
|
||||||
(CM для входящих, SM для исходящих сообщений).
|
callback: Optional[Callable[[ServerMessage], None]] = None,
|
||||||
|
on_disconnect: Optional[Callable[['AgentApi'], None]] = None
|
||||||
Пример использования:
|
):
|
||||||
async with AgentApi("ws://localhost:8000", callback=my_callback) as agent:
|
self.id = agent_id # ID агента для словаря
|
||||||
response = await agent.send_message("Hello, agent!")
|
|
||||||
async for chunk in response:
|
|
||||||
match chunk:
|
|
||||||
case SM.EventTextChunk():
|
|
||||||
print(chunk.text, end="")
|
|
||||||
print(f" [{response.tokens} токенов]")
|
|
||||||
|
|
||||||
Атрибуты:
|
|
||||||
url: URL WebSocket сервера агента
|
|
||||||
callback: Функция обратного вызова для событий не в процессе генерации
|
|
||||||
_session: aiohttp ClientSession для WebSocket соединения
|
|
||||||
_ws: Активное WebSocket соединение
|
|
||||||
_connected: Флаг состояния соединения
|
|
||||||
_queue: Очередь для передачи чанков ответа
|
|
||||||
_listen_task: Задача для прослушивания WS
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, url: str, callback: Optional[Callable[[ServerMessage], None]] = None):
|
|
||||||
"""
|
|
||||||
Инициализирует клиент агента.
|
|
||||||
|
|
||||||
Аргументы:
|
|
||||||
url: URL WebSocket сервера (например, "ws://localhost:8000")
|
|
||||||
callback: Функция обратного вызова для событий не в процессе генерации ответа
|
|
||||||
"""
|
|
||||||
self.url = url
|
self.url = url
|
||||||
self.callback = callback
|
self.callback = callback
|
||||||
|
self.on_disconnect = on_disconnect
|
||||||
|
|
||||||
self._session: aiohttp.ClientSession | None = None
|
self._session: aiohttp.ClientSession | None = None
|
||||||
self._ws: aiohttp.ClientWebSocketResponse | None = None
|
self._ws: aiohttp.ClientWebSocketResponse | None = None
|
||||||
self._connected = False
|
self._connected = False
|
||||||
self._queue = asyncio.Queue()
|
|
||||||
|
self._current_queue: asyncio.Queue | None = None
|
||||||
|
self._request_lock = asyncio.Lock()
|
||||||
self._listen_task: asyncio.Task | None = None
|
self._listen_task: asyncio.Task | None = None
|
||||||
|
|
||||||
async def __aenter__(self):
|
async def connect(self):
|
||||||
"""
|
"""Явное подключение к агенту."""
|
||||||
Входит в контекстный менеджер, устанавливает WebSocket соединение.
|
|
||||||
|
|
||||||
Возвращает:
|
|
||||||
self: Экземпляр AgentApi
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
RuntimeError: Если не удалось установить соединение
|
|
||||||
"""
|
|
||||||
self._session = aiohttp.ClientSession()
|
self._session = aiohttp.ClientSession()
|
||||||
try:
|
try:
|
||||||
self._ws = await self._session.ws_connect(self.url, heartbeat=30)
|
self._ws = await self._session.ws_connect(self.url, heartbeat=30)
|
||||||
self._connected = True
|
|
||||||
logger.info(f"Connected to agent at {self.url}")
|
|
||||||
|
|
||||||
# Ожидаем SM.Status при открытии соединения
|
# Используем asyncio.wait_for вместо timeout в receive
|
||||||
msg = await self._ws.receive()
|
# здесь отлавливаем исключение
|
||||||
|
msg = await asyncio.wait_for(self._ws.receive(), timeout=5.0)
|
||||||
|
|
||||||
if msg.type == aiohttp.WSMsgType.TEXT:
|
if msg.type == aiohttp.WSMsgType.TEXT:
|
||||||
status_msg = ServerMessage.model_validate_json(msg.data)
|
status_msg = ServerMessage.model_validate_json(msg.data)
|
||||||
if isinstance(status_msg, SM.Status):
|
if isinstance(status_msg, SM.Status):
|
||||||
if self.callback:
|
logger.info(f"Agent {self.id} is ready")
|
||||||
self.callback(status_msg)
|
|
||||||
logger.info("Agent is ready to accept messages")
|
|
||||||
else:
|
else:
|
||||||
raise RuntimeError(
|
raise AgentException(
|
||||||
f"Expected SM.Status on connection, got {status_msg.type}")
|
"INVALID_STATUS", f"Expected SM.Status, got {status_msg.type}")
|
||||||
else:
|
else:
|
||||||
raise RuntimeError(
|
raise AgentException("UNEXPECTED_MSG_TYPE",
|
||||||
f"Unexpected message type on connection: {msg.type}")
|
f"Unexpected message type: {msg.type}")
|
||||||
|
|
||||||
# Запускаем задачу для прослушивания WS
|
self._connected = True
|
||||||
self._listen_task = asyncio.create_task(self._listen())
|
self._listen_task = asyncio.create_task(self._listen())
|
||||||
|
|
||||||
|
except asyncio.TimeoutError as e:
|
||||||
|
# Специфичная обработка таймаута
|
||||||
|
if self._ws and not self._ws.closed:
|
||||||
|
await self._ws.close()
|
||||||
|
if self._session and not self._session.closed:
|
||||||
|
await self._session.close()
|
||||||
|
|
||||||
|
raise AgentException(
|
||||||
|
"TIMEOUT", "Agent did not send initial Status message within 5 seconds") from e
|
||||||
|
|
||||||
|
except AgentException:
|
||||||
|
# Перехватываем наши собственные ошибки (INVALID_STATUS, UNEXPECTED_MSG_TYPE),
|
||||||
|
# закрываем ресурсы и пробрасываем ошибку дальше без изменений!
|
||||||
|
if self._ws and not self._ws.closed:
|
||||||
|
await self._ws.close()
|
||||||
|
if self._session and not self._session.closed:
|
||||||
|
await self._session.close()
|
||||||
|
raise
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
await self._session.close()
|
# Обработка всех остальных ошибок (например, aiohttp.ClientConnectionError)
|
||||||
raise RuntimeError(f"Failed to connect to agent: {e}") from e
|
if self._ws and not self._ws.closed:
|
||||||
|
await self._ws.close()
|
||||||
|
if self._session and not self._session.closed:
|
||||||
|
await self._session.close()
|
||||||
|
|
||||||
return self
|
# Можно оставить RuntimeError, а можно тоже завернуть в AgentException
|
||||||
|
raise AgentException(code="CONNECTION_ERROR",
|
||||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
details=f"Failed to connect agent {self.id}: {e}") from e
|
||||||
"""
|
|
||||||
Выходит из контекстного менеджера, закрывает WebSocket соединение.
|
|
||||||
|
|
||||||
Аргументы:
|
|
||||||
exc_type: Тип исключения, если оно произошло
|
|
||||||
exc_val: Значение исключения
|
|
||||||
exc_tb: Traceback исключения
|
|
||||||
"""
|
|
||||||
await self.close()
|
|
||||||
|
|
||||||
async def close(self):
|
async def close(self):
|
||||||
"""
|
"""Явное ручное закрытие соединения."""
|
||||||
Закрывает WebSocket соединение и завершает сессию.
|
|
||||||
"""
|
|
||||||
self._connected = False
|
|
||||||
if self._listen_task and not self._listen_task.done():
|
if self._listen_task and not self._listen_task.done():
|
||||||
|
# Это вызовет CancelledError внутри _listen и запустит finally
|
||||||
self._listen_task.cancel()
|
self._listen_task.cancel()
|
||||||
try:
|
try:
|
||||||
await self._listen_task
|
await self._listen_task
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
pass
|
pass
|
||||||
|
else:
|
||||||
|
await self._cleanup()
|
||||||
|
|
||||||
|
async def _cleanup(self):
|
||||||
|
logger.info(f"Cleaning up agent {self.id}")
|
||||||
|
self._connected = False # Сбрасываем в любом случае
|
||||||
|
|
||||||
|
if self._current_queue:
|
||||||
|
await self._current_queue.put(ConnectionError("Connection closed"))
|
||||||
|
|
||||||
if self._ws and not self._ws.closed:
|
if self._ws and not self._ws.closed:
|
||||||
await self._ws.close()
|
await self._ws.close()
|
||||||
logger.info("WebSocket connection closed")
|
|
||||||
|
|
||||||
if self._session:
|
if self._session and not self._session.closed:
|
||||||
await self._session.close()
|
await self._session.close()
|
||||||
logger.info("Client session closed")
|
# aiohttp рекомендует небольшую паузу для корректного закрытия TCP соединений под капотом
|
||||||
|
await asyncio.sleep(0.0)
|
||||||
|
|
||||||
async def send_message(self, text: str) -> 'ResponseIterator':
|
if self.on_disconnect:
|
||||||
|
try:
|
||||||
|
# Важно: callback синхронный, он не должен делать тяжелых IO операций
|
||||||
|
self.on_disconnect(self)
|
||||||
|
self.on_disconnect = None # Защита от двойного вызова
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in on_disconnect: {e}")
|
||||||
|
|
||||||
|
async def send_message(self, text: str) -> AsyncIterator[SM.AgentEvent]:
|
||||||
"""
|
"""
|
||||||
Отправляет сообщение от пользователя на сервер и возвращает итератор для получения ответа.
|
Нативный асинхронный генератор.
|
||||||
|
Не требует отдельного класса ResponseIterator.
|
||||||
Аргументы:
|
Гарантированно освобождает блокировку.
|
||||||
text: Текст сообщения пользователя
|
|
||||||
|
|
||||||
Возвращает:
|
|
||||||
ResponseIterator: Объект для асинхронной итерации по чанкам ответа
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
RuntimeError: Если соединение не установлено
|
|
||||||
aiohttp.ClientError: Если не удалось отправить сообщение
|
|
||||||
"""
|
"""
|
||||||
if not self._connected or not self._ws:
|
if not self._connected or not self._ws:
|
||||||
raise RuntimeError(
|
raise AgentException(code="NOT_CONNECTED",
|
||||||
"Not connected to agent. Use 'async with' context manager.")
|
details="Not connected. Call connect() first.")
|
||||||
|
|
||||||
message = CM.UserMessage(
|
if self._request_lock.locked():
|
||||||
type=CM.Type.USER_MESSAGE,
|
raise AgentBusyException(
|
||||||
text=text
|
"Agent is currently processing another request")
|
||||||
)
|
|
||||||
|
|
||||||
|
# Блокируем параллельные запросы
|
||||||
|
# если идет стриминг ответа, то при попытки отправить новое сообщение будет ошибка - ее рейзим(делаем AgentBusyError)
|
||||||
|
# храню состояние занят ли агент или нет self.is_busy - если занят, то кидаем исключение (состояние занятости агента хранится lock)
|
||||||
|
await self._request_lock.acquire()
|
||||||
try:
|
try:
|
||||||
await self._ws.send_str(message.model_dump_json())
|
self._current_queue = asyncio.Queue()
|
||||||
logger.debug(f"Sent user message: {text[:100]}...")
|
|
||||||
except Exception as e:
|
|
||||||
self._connected = False
|
|
||||||
raise aiohttp.ClientError(f"Failed to send message: {e}") from e
|
|
||||||
|
|
||||||
return ResponseIterator(self._queue)
|
message = CM.UserMessage(
|
||||||
|
type=CM.Type.USER_MESSAGE,
|
||||||
|
text=text
|
||||||
|
)
|
||||||
|
|
||||||
|
await self._ws.send_str(message.model_dump_json())
|
||||||
|
logger.debug(f"[{self.id}] Sent message: {text[:50]}...")
|
||||||
|
|
||||||
|
# Читаем ответы из очереди
|
||||||
|
while True:
|
||||||
|
chunk = await self._current_queue.get()
|
||||||
|
|
||||||
|
# Если прилетела ошибка (от агента или из _listen)
|
||||||
|
if isinstance(chunk, Exception):
|
||||||
|
raise chunk
|
||||||
|
|
||||||
|
# Если конец ответа
|
||||||
|
if isinstance(chunk, SM.EventEnd):
|
||||||
|
# Если вам нужны tokens_used, можно сохранить их в атрибут self
|
||||||
|
break
|
||||||
|
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
finally:
|
||||||
|
# Если наш сервер (бэкенд агента) поддерживает команду отмены (например, специальный CM.Type.STOP_MESSAGE), то блок finally должен выглядеть так:
|
||||||
|
# if self._ws and not self._ws.closed:
|
||||||
|
# # Отправляем серверу команду "Хватит генерировать!"
|
||||||
|
# stop_msg = CM.UserMessage(type=CM.Type.STOP_GENERATION, text="")
|
||||||
|
# # (используйте правильную структуру вашего CM)
|
||||||
|
|
||||||
|
# # Запускаем отправку как таску, чтобы не блокировать finally
|
||||||
|
# asyncio.create_task(self._ws.send_str(stop_msg.model_dump_json()))
|
||||||
|
if self._current_queue:
|
||||||
|
# 1. Сохраняем ссылку на локальную переменную и отвязываем от класса.
|
||||||
|
# Это гарантирует, что _listen больше не положит сюда новые сообщения.
|
||||||
|
orphan_queue = self._current_queue
|
||||||
|
self._current_queue = None
|
||||||
|
|
||||||
|
# 2. Вычищаем (drain) очередь от сообщений, которые клиент не успел забрать
|
||||||
|
while not orphan_queue.empty():
|
||||||
|
try:
|
||||||
|
orphan_msg = orphan_queue.get_nowait()
|
||||||
|
|
||||||
|
# Если это исключение, просто логируем, в коллбек его кидать не стоит
|
||||||
|
if isinstance(orphan_msg, Exception):
|
||||||
|
logger.debug(
|
||||||
|
f"[{self.id}] Dropped exception from queue during cleanup: {orphan_msg}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 3. Отправляем "мусорные/осиротевшие" куски в callback
|
||||||
|
if self.callback:
|
||||||
|
self.callback(orphan_msg)
|
||||||
|
else:
|
||||||
|
logger.debug(
|
||||||
|
f"[{self.id}] Dropped orphaned message during cleanup")
|
||||||
|
|
||||||
|
except asyncio.QueueEmpty:
|
||||||
|
break
|
||||||
|
|
||||||
|
# Освобождаем лок, чтобы разрешить новые запросы
|
||||||
|
if self._request_lock.locked():
|
||||||
|
self._request_lock.release()
|
||||||
|
|
||||||
async def _listen(self):
|
async def _listen(self):
|
||||||
"""
|
""""
|
||||||
Прослушивает WebSocket в фоне и обрабатывает сообщения.
|
Прослушивание вебсокета.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
async for msg in self._ws:
|
async for msg in self._ws:
|
||||||
|
|
@ -178,73 +223,63 @@ class AgentApi:
|
||||||
try:
|
try:
|
||||||
outgoing_msg = ServerMessage.model_validate_json(
|
outgoing_msg = ServerMessage.model_validate_json(
|
||||||
msg.data)
|
msg.data)
|
||||||
logger.debug(
|
|
||||||
f"Received message of type: {outgoing_msg.type}")
|
|
||||||
|
|
||||||
if isinstance(outgoing_msg, SM.AgentEvent):
|
if isinstance(outgoing_msg, SM.AgentEvent):
|
||||||
await self._queue.put(outgoing_msg)
|
if self._current_queue:
|
||||||
elif isinstance(outgoing_msg, SM.Status):
|
await self._current_queue.put(outgoing_msg)
|
||||||
if self.callback:
|
# Если очереди нет (клиент отменил запрос), но токены идут — шлем их в коллбек
|
||||||
|
elif self.callback:
|
||||||
self.callback(outgoing_msg)
|
self.callback(outgoing_msg)
|
||||||
logger.info("Agent status update")
|
else:
|
||||||
|
logger.warning(
|
||||||
|
f"[{self.id}] AgentEvent without active request")
|
||||||
|
|
||||||
|
elif isinstance(outgoing_msg, SM.EventEnd):
|
||||||
|
if self._current_queue:
|
||||||
|
await self._current_queue.put(outgoing_msg)
|
||||||
|
|
||||||
elif isinstance(outgoing_msg, SM.Error):
|
elif isinstance(outgoing_msg, SM.Error):
|
||||||
if self.callback:
|
if self.callback:
|
||||||
self.callback(outgoing_msg)
|
self.callback(outgoing_msg)
|
||||||
error = AgentException(
|
error = AgentException(
|
||||||
outgoing_msg.code, outgoing_msg.details)
|
outgoing_msg.code, outgoing_msg.details)
|
||||||
logger.error(f"Agent error: {error}")
|
logger.error(f"[{self.id}] Agent error: {error}")
|
||||||
# Не бросаем исключение, а вызываем callback
|
if self._current_queue:
|
||||||
|
await self._current_queue.put(error)
|
||||||
|
|
||||||
elif isinstance(outgoing_msg, SM.GracefulDisconnect):
|
elif isinstance(outgoing_msg, SM.GracefulDisconnect):
|
||||||
if self.callback:
|
if self.callback:
|
||||||
self.callback(outgoing_msg)
|
self.callback(outgoing_msg)
|
||||||
logger.info("Agent gracefully disconnecting")
|
logger.info(
|
||||||
self._connected = False
|
f"[{self.id}] Gracefully disconnecting")
|
||||||
break
|
break # Выход из цикла приведет к finally -> _cleanup
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# Другие типы через callback
|
# написать лог о неизвестном сообщении -> вывод неожиданного сообщения в лог
|
||||||
|
logger.warning(
|
||||||
|
f"[{self.id}] Unknown message type: {outgoing_msg.type}"
|
||||||
|
)
|
||||||
if self.callback:
|
if self.callback:
|
||||||
self.callback(outgoing_msg)
|
self.callback(outgoing_msg)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to deserialize message: {e}")
|
logger.error(
|
||||||
if self.callback:
|
f"[{self.id}] Failed to deserialize message: {e}")
|
||||||
# Можно создать специальное сообщение об ошибке
|
if self._current_queue:
|
||||||
pass
|
await self._current_queue.put(
|
||||||
|
AgentException(
|
||||||
|
"PARSE_ERROR", f"Validation failed: {e}")
|
||||||
|
)
|
||||||
|
|
||||||
elif msg.type == aiohttp.WSMsgType.ERROR:
|
elif msg.type in (aiohttp.WSMsgType.ERROR, aiohttp.WSMsgType.CLOSED):
|
||||||
error_msg = f"WebSocket error: {self._ws.exception()}"
|
logger.error(
|
||||||
logger.error(error_msg)
|
f"[{self.id}] WebSocket closed/error: {msg.type}")
|
||||||
self._connected = False
|
|
||||||
break
|
break
|
||||||
|
|
||||||
elif msg.type == aiohttp.WSMsgType.CLOSED:
|
|
||||||
logger.info("WebSocket connection closed by server")
|
|
||||||
self._connected = False
|
|
||||||
break
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error in listen loop: {e}")
|
|
||||||
self._connected = False
|
|
||||||
|
|
||||||
|
|
||||||
class ResponseIterator:
|
|
||||||
"""
|
|
||||||
Асинхронный итератор для получения чанков ответа от агента.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, queue: asyncio.Queue):
|
|
||||||
self._queue = queue
|
|
||||||
self.tokens = 0
|
|
||||||
|
|
||||||
def __aiter__(self):
|
|
||||||
return self
|
|
||||||
|
|
||||||
async def __anext__(self):
|
|
||||||
try:
|
|
||||||
chunk = await self._queue.get()
|
|
||||||
if isinstance(chunk, SM.EventEnd):
|
|
||||||
self.tokens = chunk.tokens_used
|
|
||||||
raise StopAsyncIteration
|
|
||||||
return chunk
|
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
raise StopAsyncIteration
|
pass # Нормальное прерывание при self.close()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[{self.id}] Error in listen loop: {e}")
|
||||||
|
finally:
|
||||||
|
# Как только цикл прослушивания закончился, запускаем процедуру очистки
|
||||||
|
await self._cleanup()
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue