219 lines
8.5 KiB
Python
219 lines
8.5 KiB
Python
import logging
|
||
from typing import AsyncGenerator
|
||
import aiohttp
|
||
from models import IM, OM, ClientMessage, ServerMessage
|
||
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
class AgentException(Exception):
|
||
"""
|
||
Кастомное исключение для ошибок, полученных от агента.
|
||
|
||
Атрибуты:
|
||
code: Код ошибки из сообщения OM.Error
|
||
details: Детали ошибки
|
||
"""
|
||
|
||
def __init__(self, code: str, details: str):
|
||
self.code = code
|
||
self.details = details
|
||
super().__init__(f"Agent error ({code}): {details}")
|
||
|
||
|
||
class AgentApi:
|
||
"""
|
||
Асинхронный клиент для взаимодействия с AI-агентом через WebSocket.
|
||
|
||
Класс инкапсулирует обмен сообщениями согласно контракту Pydantic
|
||
(IM для входящих, OM для исходящих сообщений).
|
||
|
||
Пример использования:
|
||
async with AgentApi("ws://localhost:8000") as agent:
|
||
await agent.send_user_message("Hello, agent!")
|
||
async for message in agent.listen():
|
||
print(message)
|
||
|
||
Атрибуты:
|
||
url: URL WebSocket сервера агента
|
||
_session: aiohttp ClientSession для WebSocket соединения
|
||
_ws: Активное WebSocket соединение
|
||
_connected: Флаг состояния соединения
|
||
"""
|
||
|
||
def __init__(self, url: str):
|
||
"""
|
||
Инициализирует клиент агента.
|
||
|
||
Аргументы:
|
||
url: URL WebSocket сервера (например, "ws://localhost:8000")
|
||
"""
|
||
self.url = url
|
||
self._session: aiohttp.ClientSession | None = None
|
||
self._ws: aiohttp.ClientWebSocketResponse | None = None
|
||
self._connected = False
|
||
|
||
async def __aenter__(self):
|
||
"""
|
||
Входит в контекстный менеджер, устанавливает WebSocket соединение.
|
||
|
||
Возвращает:
|
||
self: Экземпляр AgentApi
|
||
|
||
Raises:
|
||
RuntimeError: Если не удалось установить соединение
|
||
"""
|
||
self._session = aiohttp.ClientSession()
|
||
try:
|
||
self._ws = await self._session.ws_connect(self.url, heartbeat=30)
|
||
self._connected = True
|
||
logger.info(f"Connected to agent at {self.url}")
|
||
except Exception as e:
|
||
await self._session.close()
|
||
raise RuntimeError(f"Failed to connect to agent: {e}") from e
|
||
|
||
return self
|
||
|
||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||
"""
|
||
Выходит из контекстного менеджера, закрывает WebSocket соединение.
|
||
|
||
Аргументы:
|
||
exc_type: Тип исключения, если оно произошло
|
||
exc_val: Значение исключения
|
||
exc_tb: Traceback исключения
|
||
"""
|
||
await self.close()
|
||
|
||
async def close(self):
|
||
"""
|
||
Закрывает WebSocket соединение и завершает сессию.
|
||
"""
|
||
self._connected = False
|
||
if self._ws and not self._ws.closed:
|
||
await self._ws.close()
|
||
logger.info("WebSocket connection closed")
|
||
|
||
if self._session:
|
||
await self._session.close()
|
||
logger.info("Client session closed")
|
||
|
||
async def send_user_message(self, text: str) -> None:
|
||
"""
|
||
Отправляет сообщение от пользователя на сервер.
|
||
|
||
Аргументы:
|
||
text: Текст сообщения пользователя
|
||
|
||
Raises:
|
||
RuntimeError: Если соединение не установлено
|
||
aiohttp.ClientError: Если не удалось отправить сообщение
|
||
"""
|
||
if not self._connected or not self._ws:
|
||
raise RuntimeError(
|
||
"Not connected to agent. Use 'async with' context manager.")
|
||
|
||
message = IM.UserMessage(
|
||
type=IM.Type.USER_MESSAGE,
|
||
text=text
|
||
)
|
||
|
||
try:
|
||
await self._ws.send_str(message.model_dump_json())
|
||
logger.debug(f"Sent user message: {text[:100]}...")
|
||
except Exception as e:
|
||
self._connected = False
|
||
raise aiohttp.ClientError(f"Failed to send message: {e}") from e
|
||
|
||
async def listen(self) -> AsyncGenerator[ServerMessage, None]:
|
||
"""
|
||
Читает поток сообщений от сервера и десериализует их.
|
||
|
||
Это асинхронный генератор, который обрабатывает:
|
||
- OM.Status: Логирует о готовности
|
||
- OM.Error: Бросает AgentException
|
||
- OM.GracefulDisconnect: Корректно завершает соединение
|
||
- OM.AgentEvent: Возвращает объект события
|
||
|
||
Yields:
|
||
ServerMessage: Десериализованное сообщение от сервера
|
||
|
||
Raises:
|
||
RuntimeError: Если соединение не установлено
|
||
AgentException: Если сервер отправил ошибку
|
||
|
||
Пример:
|
||
async for message in agent.listen():
|
||
if isinstance(message, OM.Status):
|
||
print("Agent is ready")
|
||
elif isinstance(message, OM.AgentEvent):
|
||
print(f"Agent event: {message.subtype}")
|
||
"""
|
||
if not self._connected or not self._ws:
|
||
raise RuntimeError(
|
||
"Not connected to agent. Use 'async with' context manager.")
|
||
|
||
try:
|
||
async for msg in self._ws:
|
||
if msg.type == aiohttp.WSMsgType.TEXT:
|
||
try:
|
||
outgoing_msg = ServerMessage.model_validate_json(
|
||
msg.data)
|
||
logger.debug(
|
||
f"Received message of type: {outgoing_msg.type}")
|
||
|
||
# Обработка специальных событий
|
||
if isinstance(outgoing_msg, OM.Status):
|
||
logger.info("Agent is ready to accept messages")
|
||
yield outgoing_msg
|
||
|
||
elif isinstance(outgoing_msg, OM.Error):
|
||
error = AgentException(
|
||
outgoing_msg.code, outgoing_msg.details)
|
||
logger.error(f"Agent error: {error}")
|
||
raise error
|
||
|
||
elif isinstance(outgoing_msg, OM.GracefulDisconnect):
|
||
logger.info("Agent gracefully disconnecting")
|
||
self._connected = False
|
||
yield outgoing_msg
|
||
break
|
||
|
||
else:
|
||
# OM.AgentEvent и другие типы
|
||
yield outgoing_msg
|
||
|
||
except AgentException:
|
||
# Пробрасываем исключения агента дальше
|
||
raise
|
||
except Exception as e:
|
||
logger.error(f"Failed to deserialize message: {e}")
|
||
raise RuntimeError(
|
||
f"Message deserialization error: {e}") from e
|
||
|
||
elif msg.type == aiohttp.WSMsgType.ERROR:
|
||
error_msg = f"WebSocket error: {self._ws.exception()}"
|
||
logger.error(error_msg)
|
||
self._connected = False
|
||
raise RuntimeError(error_msg)
|
||
|
||
elif msg.type == aiohttp.WSMsgType.CLOSED:
|
||
logger.info("WebSocket connection closed by server")
|
||
self._connected = False
|
||
break
|
||
|
||
except AgentException:
|
||
raise
|
||
except RuntimeError:
|
||
raise
|
||
except Exception as e:
|
||
logger.error(f"Error in listen loop: {e}")
|
||
self._connected = False
|
||
raise RuntimeError(f"Listen loop error: {e}") from e
|
||
finally:
|
||
# Do not reset _connected here - it should only be reset on:
|
||
# - explicit close() call
|
||
# - OM.GracefulDisconnect handling
|
||
# - actual connection errors
|
||
pass
|