#5 Реализовать клиентскую часть канала общения. Переминованы OutgoingMessage в ServerMessage и IncomingMessage в ClientMessage, OM.Status обрабатывается при открытии соединения. Реализован метод send_message.
This commit is contained in:
parent
c7a9d3446d
commit
8fd5c462ed
1 changed files with 93 additions and 62 deletions
155
api/agent_api.py
155
api/agent_api.py
|
|
@ -1,6 +1,7 @@
|
|||
import logging
|
||||
from typing import AsyncGenerator
|
||||
from typing import Callable, Optional
|
||||
import aiohttp
|
||||
import asyncio
|
||||
from models import IM, OM, ClientMessage, ServerMessage
|
||||
|
||||
|
||||
|
|
@ -30,29 +31,39 @@ class AgentApi:
|
|||
(IM для входящих, OM для исходящих сообщений).
|
||||
|
||||
Пример использования:
|
||||
async with AgentApi("ws://localhost:8000") as agent:
|
||||
await agent.send_user_message("Hello, agent!")
|
||||
async for message in agent.listen():
|
||||
print(message)
|
||||
async with AgentApi("ws://localhost:8000", callback=my_callback) as agent:
|
||||
response = await agent.send_message("Hello, agent!")
|
||||
async for chunk in response:
|
||||
match chunk:
|
||||
case OM.EventTextChunk():
|
||||
print(chunk.text, end="")
|
||||
print(f" [{response.tokens} токенов]")
|
||||
|
||||
Атрибуты:
|
||||
url: URL WebSocket сервера агента
|
||||
callback: Функция обратного вызова для событий не в процессе генерации
|
||||
_session: aiohttp ClientSession для WebSocket соединения
|
||||
_ws: Активное WebSocket соединение
|
||||
_connected: Флаг состояния соединения
|
||||
_queue: Очередь для передачи чанков ответа
|
||||
_listen_task: Задача для прослушивания WS
|
||||
"""
|
||||
|
||||
def __init__(self, url: str):
|
||||
def __init__(self, url: str, callback: Optional[Callable[[ServerMessage], None]] = None):
|
||||
"""
|
||||
Инициализирует клиент агента.
|
||||
|
||||
Аргументы:
|
||||
url: URL WebSocket сервера (например, "ws://localhost:8000")
|
||||
callback: Функция обратного вызова для событий не в процессе генерации ответа
|
||||
"""
|
||||
self.url = url
|
||||
self.callback = callback
|
||||
self._session: aiohttp.ClientSession | None = None
|
||||
self._ws: aiohttp.ClientWebSocketResponse | None = None
|
||||
self._connected = False
|
||||
self._queue = asyncio.Queue()
|
||||
self._listen_task: asyncio.Task | None = None
|
||||
|
||||
async def __aenter__(self):
|
||||
"""
|
||||
|
|
@ -69,6 +80,25 @@ class AgentApi:
|
|||
self._ws = await self._session.ws_connect(self.url, heartbeat=30)
|
||||
self._connected = True
|
||||
logger.info(f"Connected to agent at {self.url}")
|
||||
|
||||
# Ожидаем OM.Status при открытии соединения
|
||||
msg = await self._ws.receive()
|
||||
if msg.type == aiohttp.WSMsgType.TEXT:
|
||||
status_msg = ServerMessage.model_validate_json(msg.data)
|
||||
if isinstance(status_msg, OM.Status):
|
||||
if self.callback:
|
||||
self.callback(status_msg)
|
||||
logger.info("Agent is ready to accept messages")
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Expected OM.Status on connection, got {status_msg.type}")
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Unexpected message type on connection: {msg.type}")
|
||||
|
||||
# Запускаем задачу для прослушивания WS
|
||||
self._listen_task = asyncio.create_task(self._listen())
|
||||
|
||||
except Exception as e:
|
||||
await self._session.close()
|
||||
raise RuntimeError(f"Failed to connect to agent: {e}") from e
|
||||
|
|
@ -91,6 +121,13 @@ class AgentApi:
|
|||
Закрывает WebSocket соединение и завершает сессию.
|
||||
"""
|
||||
self._connected = False
|
||||
if self._listen_task and not self._listen_task.done():
|
||||
self._listen_task.cancel()
|
||||
try:
|
||||
await self._listen_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
if self._ws and not self._ws.closed:
|
||||
await self._ws.close()
|
||||
logger.info("WebSocket connection closed")
|
||||
|
|
@ -99,13 +136,16 @@ class AgentApi:
|
|||
await self._session.close()
|
||||
logger.info("Client session closed")
|
||||
|
||||
async def send_user_message(self, text: str) -> None:
|
||||
async def send_message(self, text: str) -> 'ResponseIterator':
|
||||
"""
|
||||
Отправляет сообщение от пользователя на сервер.
|
||||
Отправляет сообщение от пользователя на сервер и возвращает итератор для получения ответа.
|
||||
|
||||
Аргументы:
|
||||
text: Текст сообщения пользователя
|
||||
|
||||
Возвращает:
|
||||
ResponseIterator: Объект для асинхронной итерации по чанкам ответа
|
||||
|
||||
Raises:
|
||||
RuntimeError: Если соединение не установлено
|
||||
aiohttp.ClientError: Если не удалось отправить сообщение
|
||||
|
|
@ -126,34 +166,12 @@ class AgentApi:
|
|||
self._connected = False
|
||||
raise aiohttp.ClientError(f"Failed to send message: {e}") from e
|
||||
|
||||
async def listen(self) -> AsyncGenerator[ServerMessage, None]:
|
||||
return ResponseIterator(self._queue)
|
||||
|
||||
async def _listen(self):
|
||||
"""
|
||||
Читает поток сообщений от сервера и десериализует их.
|
||||
|
||||
Это асинхронный генератор, который обрабатывает:
|
||||
- OM.Status: Логирует о готовности
|
||||
- OM.Error: Бросает AgentException
|
||||
- OM.GracefulDisconnect: Корректно завершает соединение
|
||||
- OM.AgentEvent: Возвращает объект события
|
||||
|
||||
Yields:
|
||||
ServerMessage: Десериализованное сообщение от сервера
|
||||
|
||||
Raises:
|
||||
RuntimeError: Если соединение не установлено
|
||||
AgentException: Если сервер отправил ошибку
|
||||
|
||||
Пример:
|
||||
async for message in agent.listen():
|
||||
if isinstance(message, OM.Status):
|
||||
print("Agent is ready")
|
||||
elif isinstance(message, OM.AgentEvent):
|
||||
print(f"Agent event: {message.subtype}")
|
||||
Прослушивает WebSocket в фоне и обрабатывает сообщения.
|
||||
"""
|
||||
if not self._connected or not self._ws:
|
||||
raise RuntimeError(
|
||||
"Not connected to agent. Use 'async with' context manager.")
|
||||
|
||||
try:
|
||||
async for msg in self._ws:
|
||||
if msg.type == aiohttp.WSMsgType.TEXT:
|
||||
|
|
@ -163,57 +181,70 @@ class AgentApi:
|
|||
logger.debug(
|
||||
f"Received message of type: {outgoing_msg.type}")
|
||||
|
||||
# Обработка специальных событий
|
||||
if isinstance(outgoing_msg, OM.Status):
|
||||
logger.info("Agent is ready to accept messages")
|
||||
yield outgoing_msg
|
||||
|
||||
if isinstance(outgoing_msg, OM.AgentEvent):
|
||||
await self._queue.put(outgoing_msg)
|
||||
elif isinstance(outgoing_msg, OM.Status):
|
||||
if self.callback:
|
||||
self.callback(outgoing_msg)
|
||||
logger.info("Agent status update")
|
||||
elif isinstance(outgoing_msg, OM.Error):
|
||||
if self.callback:
|
||||
self.callback(outgoing_msg)
|
||||
error = AgentException(
|
||||
outgoing_msg.code, outgoing_msg.details)
|
||||
logger.error(f"Agent error: {error}")
|
||||
raise error
|
||||
|
||||
# Не бросаем исключение, а вызываем callback
|
||||
elif isinstance(outgoing_msg, OM.GracefulDisconnect):
|
||||
if self.callback:
|
||||
self.callback(outgoing_msg)
|
||||
logger.info("Agent gracefully disconnecting")
|
||||
self._connected = False
|
||||
yield outgoing_msg
|
||||
break
|
||||
|
||||
else:
|
||||
# OM.AgentEvent и другие типы
|
||||
yield outgoing_msg
|
||||
# Другие типы через callback
|
||||
if self.callback:
|
||||
self.callback(outgoing_msg)
|
||||
|
||||
except AgentException:
|
||||
# Пробрасываем исключения агента дальше
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to deserialize message: {e}")
|
||||
raise RuntimeError(
|
||||
f"Message deserialization error: {e}") from e
|
||||
if self.callback:
|
||||
# Можно создать специальное сообщение об ошибке
|
||||
pass
|
||||
|
||||
elif msg.type == aiohttp.WSMsgType.ERROR:
|
||||
error_msg = f"WebSocket error: {self._ws.exception()}"
|
||||
logger.error(error_msg)
|
||||
self._connected = False
|
||||
raise RuntimeError(error_msg)
|
||||
break
|
||||
|
||||
elif msg.type == aiohttp.WSMsgType.CLOSED:
|
||||
logger.info("WebSocket connection closed by server")
|
||||
self._connected = False
|
||||
break
|
||||
|
||||
except AgentException:
|
||||
raise
|
||||
except RuntimeError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error in listen loop: {e}")
|
||||
self._connected = False
|
||||
raise RuntimeError(f"Listen loop error: {e}") from e
|
||||
finally:
|
||||
# Do not reset _connected here - it should only be reset on:
|
||||
# - explicit close() call
|
||||
# - OM.GracefulDisconnect handling
|
||||
# - actual connection errors
|
||||
pass
|
||||
|
||||
|
||||
class ResponseIterator:
|
||||
"""
|
||||
Асинхронный итератор для получения чанков ответа от агента.
|
||||
"""
|
||||
|
||||
def __init__(self, queue: asyncio.Queue):
|
||||
self._queue = queue
|
||||
self.tokens = 0
|
||||
|
||||
def __aiter__(self):
|
||||
return self
|
||||
|
||||
async def __anext__(self):
|
||||
try:
|
||||
chunk = await self._queue.get()
|
||||
if isinstance(chunk, OM.EventEnd):
|
||||
self.tokens = chunk.tokens_used
|
||||
raise StopAsyncIteration
|
||||
return chunk
|
||||
except asyncio.CancelledError:
|
||||
raise StopAsyncIteration
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue