agent_api/api/agent_api.py

250 lines
10 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import logging
from typing import Callable, Optional
import aiohttp
import asyncio
from models import IM, OM, ClientMessage, ServerMessage
logger = logging.getLogger(__name__)
class AgentException(Exception):
"""
Кастомное исключение для ошибок, полученных от агента.
Атрибуты:
code: Код ошибки из сообщения OM.Error
details: Детали ошибки
"""
def __init__(self, code: str, details: str):
self.code = code
self.details = details
super().__init__(f"Agent error ({code}): {details}")
class AgentApi:
"""
Асинхронный клиент для взаимодействия с AI-агентом через WebSocket.
Класс инкапсулирует обмен сообщениями согласно контракту Pydantic
(IM для входящих, OM для исходящих сообщений).
Пример использования:
async with AgentApi("ws://localhost:8000", callback=my_callback) as agent:
response = await agent.send_message("Hello, agent!")
async for chunk in response:
match chunk:
case OM.EventTextChunk():
print(chunk.text, end="")
print(f" [{response.tokens} токенов]")
Атрибуты:
url: URL WebSocket сервера агента
callback: Функция обратного вызова для событий не в процессе генерации
_session: aiohttp ClientSession для WebSocket соединения
_ws: Активное WebSocket соединение
_connected: Флаг состояния соединения
_queue: Очередь для передачи чанков ответа
_listen_task: Задача для прослушивания WS
"""
def __init__(self, url: str, callback: Optional[Callable[[ServerMessage], None]] = None):
"""
Инициализирует клиент агента.
Аргументы:
url: URL WebSocket сервера (например, "ws://localhost:8000")
callback: Функция обратного вызова для событий не в процессе генерации ответа
"""
self.url = url
self.callback = callback
self._session: aiohttp.ClientSession | None = None
self._ws: aiohttp.ClientWebSocketResponse | None = None
self._connected = False
self._queue = asyncio.Queue()
self._listen_task: asyncio.Task | None = None
async def __aenter__(self):
"""
Входит в контекстный менеджер, устанавливает WebSocket соединение.
Возвращает:
self: Экземпляр AgentApi
Raises:
RuntimeError: Если не удалось установить соединение
"""
self._session = aiohttp.ClientSession()
try:
self._ws = await self._session.ws_connect(self.url, heartbeat=30)
self._connected = True
logger.info(f"Connected to agent at {self.url}")
# Ожидаем OM.Status при открытии соединения
msg = await self._ws.receive()
if msg.type == aiohttp.WSMsgType.TEXT:
status_msg = ServerMessage.model_validate_json(msg.data)
if isinstance(status_msg, OM.Status):
if self.callback:
self.callback(status_msg)
logger.info("Agent is ready to accept messages")
else:
raise RuntimeError(
f"Expected OM.Status on connection, got {status_msg.type}")
else:
raise RuntimeError(
f"Unexpected message type on connection: {msg.type}")
# Запускаем задачу для прослушивания WS
self._listen_task = asyncio.create_task(self._listen())
except Exception as e:
await self._session.close()
raise RuntimeError(f"Failed to connect to agent: {e}") from e
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
"""
Выходит из контекстного менеджера, закрывает WebSocket соединение.
Аргументы:
exc_type: Тип исключения, если оно произошло
exc_val: Значение исключения
exc_tb: Traceback исключения
"""
await self.close()
async def close(self):
"""
Закрывает WebSocket соединение и завершает сессию.
"""
self._connected = False
if self._listen_task and not self._listen_task.done():
self._listen_task.cancel()
try:
await self._listen_task
except asyncio.CancelledError:
pass
if self._ws and not self._ws.closed:
await self._ws.close()
logger.info("WebSocket connection closed")
if self._session:
await self._session.close()
logger.info("Client session closed")
async def send_message(self, text: str) -> 'ResponseIterator':
"""
Отправляет сообщение от пользователя на сервер и возвращает итератор для получения ответа.
Аргументы:
text: Текст сообщения пользователя
Возвращает:
ResponseIterator: Объект для асинхронной итерации по чанкам ответа
Raises:
RuntimeError: Если соединение не установлено
aiohttp.ClientError: Если не удалось отправить сообщение
"""
if not self._connected or not self._ws:
raise RuntimeError(
"Not connected to agent. Use 'async with' context manager.")
message = IM.UserMessage(
type=IM.Type.USER_MESSAGE,
text=text
)
try:
await self._ws.send_str(message.model_dump_json())
logger.debug(f"Sent user message: {text[:100]}...")
except Exception as e:
self._connected = False
raise aiohttp.ClientError(f"Failed to send message: {e}") from e
return ResponseIterator(self._queue)
async def _listen(self):
"""
Прослушивает WebSocket в фоне и обрабатывает сообщения.
"""
try:
async for msg in self._ws:
if msg.type == aiohttp.WSMsgType.TEXT:
try:
outgoing_msg = ServerMessage.model_validate_json(
msg.data)
logger.debug(
f"Received message of type: {outgoing_msg.type}")
if isinstance(outgoing_msg, OM.AgentEvent):
await self._queue.put(outgoing_msg)
elif isinstance(outgoing_msg, OM.Status):
if self.callback:
self.callback(outgoing_msg)
logger.info("Agent status update")
elif isinstance(outgoing_msg, OM.Error):
if self.callback:
self.callback(outgoing_msg)
error = AgentException(
outgoing_msg.code, outgoing_msg.details)
logger.error(f"Agent error: {error}")
# Не бросаем исключение, а вызываем callback
elif isinstance(outgoing_msg, OM.GracefulDisconnect):
if self.callback:
self.callback(outgoing_msg)
logger.info("Agent gracefully disconnecting")
self._connected = False
break
else:
# Другие типы через callback
if self.callback:
self.callback(outgoing_msg)
except Exception as e:
logger.error(f"Failed to deserialize message: {e}")
if self.callback:
# Можно создать специальное сообщение об ошибке
pass
elif msg.type == aiohttp.WSMsgType.ERROR:
error_msg = f"WebSocket error: {self._ws.exception()}"
logger.error(error_msg)
self._connected = False
break
elif msg.type == aiohttp.WSMsgType.CLOSED:
logger.info("WebSocket connection closed by server")
self._connected = False
break
except Exception as e:
logger.error(f"Error in listen loop: {e}")
self._connected = False
class ResponseIterator:
"""
Асинхронный итератор для получения чанков ответа от агента.
"""
def __init__(self, queue: asyncio.Queue):
self._queue = queue
self.tokens = 0
def __aiter__(self):
return self
async def __anext__(self):
try:
chunk = await self._queue.get()
if isinstance(chunk, OM.EventEnd):
self.tokens = chunk.tokens_used
raise StopAsyncIteration
return chunk
except asyncio.CancelledError:
raise StopAsyncIteration