agent_api/lambda_agent_api/agent_api.py

331 lines
16 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import logging
from typing import Callable, Optional, AsyncIterator
import aiohttp
import asyncio
from urllib.parse import urljoin
from aiohttp import WSServerHandshakeError
from lambda_agent_api.server import *
from lambda_agent_api.client import *
logger = logging.getLogger(__name__)
class AgentException(Exception):
def __init__(self, code: str, details: str):
self.code = code
self.details = details
super().__init__(f"Agent error ({code}): {details}")
class AgentBusyException(AgentException):
def __init__(self, details: str):
super().__init__("BUSY", details)
class AgentApi:
def __init__(
self,
agent_id: str,
base_url: str,
callback: Optional[Callable[[ServerMessage], None]] = None,
on_disconnect: Optional[Callable[["AgentApi"], None]] = None,
chat_id: int = 0, # значение по умолчанию для обратной совместимости # значение по умолчанию для обратной совместимости
):
self.id = agent_id # ID агента для словаря
self.chat_id = chat_id
self.url = urljoin(base_url, f"v1/agent_ws/{chat_id}/")
self.callback = callback
self.on_disconnect = on_disconnect
self._session: aiohttp.ClientSession | None = None
self._ws: aiohttp.ClientWebSocketResponse | None = None
self._connected = False
self._current_queue: asyncio.Queue | None = None
self._request_lock = asyncio.Lock()
self._listen_task: asyncio.Task | None = None
async def connect(self):
"""Явное подключение к агенту.
:raise AgentBusyException: Чат занят другим клиентом.
:raise AgentException: Непредвиденная ошибка протокола, см. code и details
"""
self._session = aiohttp.ClientSession()
try:
self._ws = await self._session.ws_connect(self.url, heartbeat=30)
# Используем asyncio.wait_for вместо timeout в receive
# здесь отлавливаем исключение
msg = await asyncio.wait_for(self._ws.receive(), timeout=5.0)
if msg.type == aiohttp.WSMsgType.TEXT:
status_msg = ServerMessage.validate_json(msg.data)
if isinstance(status_msg, MsgStatus):
logger.info(f"Agent {self.id} is ready")
else:
raise AgentException(
"INVALID_STATUS", f"Expected SM.Status, got {status_msg.type}"
)
else:
raise AgentException(
"UNEXPECTED_MSG_TYPE", f"Unexpected message type: {msg.type}"
)
self._connected = True
self._listen_task = asyncio.create_task(self._listen())
except asyncio.TimeoutError as e:
# Специфичная обработка таймаута
if self._ws and not self._ws.closed:
await self._ws.close()
if self._session and not self._session.closed:
await self._session.close()
raise AgentException(
"TIMEOUT", "Agent did not send initial Status message within 5 seconds"
) from e
except AgentException:
# Перехватываем наши собственные ошибки (INVALID_STATUS, UNEXPECTED_MSG_TYPE),
# закрываем ресурсы и пробрасываем ошибку дальше без изменений!
if self._ws and not self._ws.closed:
await self._ws.close()
if self._session and not self._session.closed:
await self._session.close()
raise
except WSServerHandshakeError as e: # если при открытии подключения сервер вернул какую-то ошибку
if self._session and not self._session.closed:
await self._session.close()
# во-первых, aiohttp зачем-то приводит WS коды ошибок к HTTP.
# во-вторых, делает он это некорректно. Любой неизвестный код WS ошибки становится 403 HTTP
# в-третьих, он не передает оригинальное сообщение об ошибке.
# Т. е. какое бы сообщение сервер не отправлял, тут всегда будет "Invalid response status"
# см. site-packages\aiohttp\client.py, line 1104, in _ws_connect
# итого понять реальную причину ошибки почти невозможно. Нужно менять библиотеку
# сейчас сервер специально кидает только ошибку с WS кодом 1008 Policy Violation, когда внутри ловит ChatBusyError
# поэтому скорее всего, если мы получили WSServerHandshakeError с 403, то это внутренний ChatBusyError
if e.status != 403:
# обрабатываем как обычную ошибку WS по примеру except блока ниже
raise AgentException(code="CONNECTION_ERROR",
details=f"Failed to connect agent {self.id}: {e}") from e
raise AgentBusyException(f"Chat {self.chat_id} is already in use by other client")
except Exception as e:
# Обработка всех остальных ошибок (например, aiohttp.ClientConnectionError)
if self._ws and not self._ws.closed:
await self._ws.close()
if self._session and not self._session.closed:
await self._session.close()
# Можно оставить RuntimeError, а можно тоже завернуть в AgentException
raise AgentException(
code="CONNECTION_ERROR",
details=f"Failed to connect agent {self.id}: {e}",
) from e
async def close(self):
"""Явное ручное закрытие соединения."""
if self._listen_task and not self._listen_task.done():
# Это вызовет CancelledError внутри _listen и запустит finally
self._listen_task.cancel()
try:
await self._listen_task
except asyncio.CancelledError:
pass
else:
await self._cleanup()
async def _cleanup(self):
logger.info(f"Cleaning up agent {self.id}")
self._connected = False # Сбрасываем в любом случае
if self._current_queue:
await self._current_queue.put(ConnectionError("Connection closed"))
if self._ws and not self._ws.closed:
await self._ws.close()
if self._session and not self._session.closed:
await self._session.close()
# aiohttp рекомендует небольшую паузу для корректного закрытия TCP соединений под капотом
await asyncio.sleep(0.0)
if self.on_disconnect:
try:
# Важно: callback синхронный, он не должен делать тяжелых IO операций
self.on_disconnect(self)
self.on_disconnect = None # Защита от двойного вызова
except Exception as e:
logger.error(f"Error in on_disconnect: {e}")
async def send_message(
self, text: str, attachments: list[str] | None = None
) -> AsyncIterator[AgentEventUnion]:
"""
Нативный асинхронный генератор.
Не требует отдельного класса ResponseIterator.
Гарантированно освобождает блокировку.
Args:
text: Текст сообщения.
attachments: Список путей к файлам относительно /workspace.
"""
if not self._connected or not self._ws:
raise AgentException(
code="NOT_CONNECTED", details="Not connected. Call connect() first."
)
if self._request_lock.locked():
raise AgentBusyException("Agent is currently processing another request")
# Блокируем параллельные запросы
# если идет стриминг ответа, то при попытки отправить новое сообщение будет ошибка - ее рейзим(делаем AgentBusyError)
# храню состояние занят ли агент или нет self.is_busy - если занят, то кидаем исключение (состояние занятости агента хранится lock)
await self._request_lock.acquire()
try:
self._current_queue = asyncio.Queue()
message = MsgUserMessage(
type=EClientMessage.USER_MESSAGE,
text=text,
attachments=attachments or [],
)
await self._ws.send_str(message.model_dump_json())
logger.debug(f"[{self.id}] Sent message: {text[:50]}...")
# Читаем ответы из очереди
while True:
chunk = await self._current_queue.get()
# Если прилетела ошибка (от агента или из _listen)
if isinstance(chunk, Exception):
raise chunk
# Если конец ответа
if isinstance(chunk, MsgEventEnd):
# Если вам нужны tokens_used, можно сохранить их в атрибут self
break
yield chunk
finally:
# Если наш сервер (бэкенд агента) поддерживает команду отмены (например, специальный CM.Type.STOP_MESSAGE), то блок finally должен выглядеть так:
# if self._ws and not self._ws.closed:
# # Отправляем серверу команду "Хватит генерировать!"
# stop_msg = CM.UserMessage(type=CM.Type.STOP_GENERATION, text="")
# # (используйте правильную структуру вашего CM)
# # Запускаем отправку как таску, чтобы не блокировать finally
# asyncio.create_task(self._ws.send_str(stop_msg.model_dump_json()))
if self._current_queue:
# 1. Сохраняем ссылку на локальную переменную и отвязываем от класса.
# Это гарантирует, что _listen больше не положит сюда новые сообщения.
orphan_queue = self._current_queue
self._current_queue = None
# 2. Вычищаем (drain) очередь от сообщений, которые клиент не успел забрать
while not orphan_queue.empty():
try:
orphan_msg = orphan_queue.get_nowait()
# Если это исключение, просто логируем, в коллбек его кидать не стоит
if isinstance(orphan_msg, Exception):
logger.debug(
f"[{self.id}] Dropped exception from queue during cleanup: {orphan_msg}"
)
continue
# 3. Отправляем "мусорные/осиротевшие" куски в callback
if self.callback:
self.callback(orphan_msg)
else:
logger.debug(
f"[{self.id}] Dropped orphaned message during cleanup"
)
except asyncio.QueueEmpty:
break
# Освобождаем лок, чтобы разрешить новые запросы
if self._request_lock.locked():
self._request_lock.release()
async def _listen(self):
"""
Прослушивание вебсокета.
"""
try:
async for msg in self._ws:
if msg.type == aiohttp.WSMsgType.TEXT:
try:
outgoing_msg = ServerMessage.validate_json(msg.data)
if isinstance(
outgoing_msg,
(
MsgEventTextChunk,
MsgEventToolCallChunk,
MsgEventToolResult,
MsgEventCustomUpdate,
MsgEventSendFile,
MsgEventEnd,
),
):
if self._current_queue:
await self._current_queue.put(outgoing_msg)
elif self.callback:
self.callback(outgoing_msg)
else:
logger.warning(
f"[{self.id}] AgentEvent without active request"
)
elif isinstance(outgoing_msg, MsgError):
if self.callback:
self.callback(outgoing_msg)
error = AgentException(
outgoing_msg.code, outgoing_msg.details
)
logger.error(f"[{self.id}] Agent error: {error}")
if self._current_queue:
await self._current_queue.put(error)
elif isinstance(outgoing_msg, MsgGracefulDisconnect):
if self.callback:
self.callback(outgoing_msg)
logger.info(f"[{self.id}] Gracefully disconnecting")
break # Выход из цикла приведет к finally -> _cleanup
else:
# написать лог о неизвестном сообщении -> вывод неожиданного сообщения в лог
logger.warning(
f"[{self.id}] Unknown message type: {outgoing_msg.type}"
)
if self.callback:
self.callback(outgoing_msg)
except Exception as e:
logger.error(f"[{self.id}] Failed to deserialize message: {e}")
if self._current_queue:
await self._current_queue.put(
AgentException("PARSE_ERROR", f"Validation failed: {e}")
)
elif msg.type in (aiohttp.WSMsgType.ERROR, aiohttp.WSMsgType.CLOSED):
logger.error(f"[{self.id}] WebSocket closed/error: {msg.type}")
break
except asyncio.CancelledError:
pass # Нормальное прерывание при self.close()
except Exception as e:
logger.error(f"[{self.id}] Error in listen loop: {e}")
finally:
# Как только цикл прослушивания закончился, запускаем процедуру очистки
await self._cleanup()