корректные pydantic модели для автоматического определения класса по полю type
This commit is contained in:
parent
b34cbaf677
commit
1e256a545b
11 changed files with 294 additions and 148 deletions
|
|
@ -8,13 +8,12 @@
|
|||
"""
|
||||
|
||||
from lambda_agent_api.agent_api import AgentApi, AgentException
|
||||
from lambda_agent_api.models import CM, SM, ClientMessage, ServerMessage
|
||||
from lambda_agent_api.server import ServerMessage
|
||||
from lambda_agent_api.client import ClientMessage
|
||||
|
||||
__all__ = [
|
||||
"AgentApi",
|
||||
"AgentException",
|
||||
"CM",
|
||||
"SM",
|
||||
"ClientMessage",
|
||||
"ServerMessage",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -3,7 +3,8 @@ from typing import Callable, Optional, AsyncIterator
|
|||
import aiohttp
|
||||
import asyncio
|
||||
|
||||
from lambda_agent_api.models import CM, SM, ClientMessage, ServerMessage
|
||||
from lambda_agent_api.server import *
|
||||
from lambda_agent_api.client import *
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -52,8 +53,8 @@ class AgentApi:
|
|||
msg = await asyncio.wait_for(self._ws.receive(), timeout=5.0)
|
||||
|
||||
if msg.type == aiohttp.WSMsgType.TEXT:
|
||||
status_msg = ServerMessage.model_validate_json(msg.data)
|
||||
if isinstance(status_msg, SM.Status):
|
||||
status_msg = ServerMessage.validate_json(msg.data)
|
||||
if isinstance(status_msg, MsgStatus):
|
||||
logger.info(f"Agent {self.id} is ready")
|
||||
else:
|
||||
raise AgentException(
|
||||
|
|
@ -130,7 +131,7 @@ class AgentApi:
|
|||
except Exception as e:
|
||||
logger.error(f"Error in on_disconnect: {e}")
|
||||
|
||||
async def send_message(self, text: str) -> AsyncIterator[SM.AgentEvent]:
|
||||
async def send_message(self, text: str) -> AsyncIterator[AgentEventUnion]:
|
||||
"""
|
||||
Нативный асинхронный генератор.
|
||||
Не требует отдельного класса ResponseIterator.
|
||||
|
|
@ -151,8 +152,8 @@ class AgentApi:
|
|||
try:
|
||||
self._current_queue = asyncio.Queue()
|
||||
|
||||
message = CM.UserMessage(
|
||||
type=CM.Type.USER_MESSAGE,
|
||||
message = MsgUserMessage(
|
||||
type=EClientMessage.USER_MESSAGE,
|
||||
text=text
|
||||
)
|
||||
|
||||
|
|
@ -168,7 +169,7 @@ class AgentApi:
|
|||
raise chunk
|
||||
|
||||
# Если конец ответа
|
||||
if isinstance(chunk, SM.EventEnd):
|
||||
if isinstance(chunk, MsgEventEnd):
|
||||
# Если вам нужны tokens_used, можно сохранить их в атрибут self
|
||||
break
|
||||
|
||||
|
|
@ -222,10 +223,10 @@ class AgentApi:
|
|||
async for msg in self._ws:
|
||||
if msg.type == aiohttp.WSMsgType.TEXT:
|
||||
try:
|
||||
outgoing_msg = ServerMessage.model_validate_json(
|
||||
outgoing_msg = ServerMessage.validate_json(
|
||||
msg.data)
|
||||
|
||||
if isinstance(outgoing_msg, SM.AgentEvent):
|
||||
if isinstance(outgoing_msg, MsgEventTextChunk):
|
||||
if self._current_queue:
|
||||
await self._current_queue.put(outgoing_msg)
|
||||
# Если очереди нет (клиент отменил запрос), но токены идут — шлем их в коллбек
|
||||
|
|
@ -235,11 +236,11 @@ class AgentApi:
|
|||
logger.warning(
|
||||
f"[{self.id}] AgentEvent without active request")
|
||||
|
||||
elif isinstance(outgoing_msg, SM.EventEnd):
|
||||
elif isinstance(outgoing_msg, MsgEventEnd):
|
||||
if self._current_queue:
|
||||
await self._current_queue.put(outgoing_msg)
|
||||
|
||||
elif isinstance(outgoing_msg, SM.Error):
|
||||
elif isinstance(outgoing_msg, MsgError):
|
||||
if self.callback:
|
||||
self.callback(outgoing_msg)
|
||||
error = AgentException(
|
||||
|
|
@ -248,7 +249,7 @@ class AgentApi:
|
|||
if self._current_queue:
|
||||
await self._current_queue.put(error)
|
||||
|
||||
elif isinstance(outgoing_msg, SM.GracefulDisconnect):
|
||||
elif isinstance(outgoing_msg, MsgGracefulDisconnect):
|
||||
if self.callback:
|
||||
self.callback(outgoing_msg)
|
||||
logger.info(
|
||||
|
|
|
|||
33
lambda_agent_api/client.py
Normal file
33
lambda_agent_api/client.py
Normal file
|
|
@ -0,0 +1,33 @@
|
|||
from enum import Enum
|
||||
from pydantic import BaseModel, Field, TypeAdapter
|
||||
from typing import Annotated, Union, Literal
|
||||
|
||||
|
||||
__all__ = ['EClientMessage', 'MsgUserMessage', 'ClientMessage']
|
||||
|
||||
|
||||
class EClientMessage(str, Enum):
|
||||
USER_MESSAGE = "USER_MESSAGE"
|
||||
|
||||
|
||||
class MsgUserMessage(BaseModel):
|
||||
"""
|
||||
Полное сообщение от пользователя.
|
||||
"""
|
||||
type: Literal[EClientMessage.USER_MESSAGE] = EClientMessage.USER_MESSAGE
|
||||
text: str
|
||||
"""
|
||||
Текст сообщения.
|
||||
"""
|
||||
|
||||
|
||||
ClientMessage = TypeAdapter(Annotated[
|
||||
Union[MsgUserMessage,],
|
||||
Field(discriminator="type")
|
||||
])
|
||||
"""
|
||||
Объединяет все типы входящих сообщений в одно для удобной автоматической десериализации.\n
|
||||
Pydantic сам определит нужный тип в зависимости от поля ``type``.\n
|
||||
Использование:\n
|
||||
msg = ClientMessage.model_validate_json(json)
|
||||
"""
|
||||
|
|
@ -1,121 +0,0 @@
|
|||
from pydantic import BaseModel, Field
|
||||
from enum import Enum
|
||||
from typing import Literal, Annotated, Union
|
||||
|
||||
|
||||
class CM:
|
||||
"""
|
||||
Namespace для моделей входящих сообщений (от клиента к серверу).\n
|
||||
CM = Client Message
|
||||
"""
|
||||
|
||||
class Type(str, Enum):
|
||||
USER_MESSAGE = "USER_MESSAGE"
|
||||
|
||||
# noinspection PyPep8Naming
|
||||
|
||||
class UserMessage(BaseModel):
|
||||
"""
|
||||
Полное сообщение от пользователя.
|
||||
"""
|
||||
type: Literal[CM.Type.USER_MESSAGE]
|
||||
text: str
|
||||
"""
|
||||
Текст сообщения.
|
||||
"""
|
||||
|
||||
|
||||
ClientMessage = Annotated[
|
||||
Union[CM.UserMessage,],
|
||||
Field(discriminator="type")
|
||||
]
|
||||
"""
|
||||
Объединяет все типы входящих сообщений в одно для удобной автоматической десериализации.\n
|
||||
Pydantic сам определит нужный тип в зависимости от поля ``type``.\n
|
||||
Использование:\n
|
||||
msg = ClientMessage.model_validate_json(json)
|
||||
"""
|
||||
|
||||
|
||||
class SM:
|
||||
"""
|
||||
Namespace для моделей исходящих сообщений (от сервера к клиенту).\n
|
||||
SM = Server Message
|
||||
"""
|
||||
|
||||
class Type(str, Enum):
|
||||
STATUS = "STATUS"
|
||||
AGENT_EVENT = "AGENT_EVENT"
|
||||
ERROR = "ERROR"
|
||||
GRACEFUL_DISCONNECT = "GRACEFUL_DISCONNECT"
|
||||
|
||||
class Status(BaseModel):
|
||||
"""
|
||||
Отправляется сервером при открытии соединения с клиентом.
|
||||
Будет дополнен информацией о готовности агента принимать сообщения.
|
||||
"""
|
||||
type: Literal[SM.Type.STATUS]
|
||||
|
||||
class AgentEventType(str, Enum):
|
||||
TEXT_CHUNK = "TEXT_CHUNK"
|
||||
END = "END"
|
||||
|
||||
class AgentEvent(BaseModel):
|
||||
"""
|
||||
Базовый класс для ивентов, которые стримит агент во время генерации ответа.
|
||||
Конкретный класс для ивента определяется по ``subtype``.
|
||||
"""
|
||||
type: Literal[SM.Type.AGENT_EVENT]
|
||||
subtype: SM.AgentEventType
|
||||
|
||||
class EventTextChunk(AgentEvent):
|
||||
"""
|
||||
Чанк текста ответа агента.
|
||||
"""
|
||||
subtype: Literal[SM.AgentEventType.TEXT_CHUNK]
|
||||
text: str
|
||||
|
||||
class EventEnd(AgentEvent):
|
||||
"""
|
||||
Агент закончил генерацию ответа.
|
||||
"""
|
||||
subtype: Literal[SM.AgentEventType.END]
|
||||
tokens_used: int
|
||||
|
||||
class Error(BaseModel):
|
||||
"""
|
||||
Неопределенная ошибка в работе агента.
|
||||
"""
|
||||
type: Literal[SM.Type.ERROR]
|
||||
code: str
|
||||
details: str
|
||||
|
||||
class GracefulDisconnect(BaseModel):
|
||||
"""
|
||||
Отправляется перед завершением работы контейнера с агентом. Например, при долгом бездействии.
|
||||
Нужно, чтобы отделять обрыв соединения из-за ошибки с необходимостью повторного подключения.
|
||||
Приход этого сообщения означает, что агент осознанно завершает работу с клиентом по какой-то причине.
|
||||
Для дальнейшего взаимодействия нужно снова обратиться к мастеру.
|
||||
"""
|
||||
type: Literal[SM.Type.GRACEFUL_DISCONNECT]
|
||||
|
||||
|
||||
AgentEventUnion = Annotated[
|
||||
Union[
|
||||
SM.EventTextChunk,
|
||||
SM.EventEnd,
|
||||
],
|
||||
Field(discriminator="subtype")
|
||||
]
|
||||
|
||||
|
||||
ServerMessage = Annotated[
|
||||
Union[SM.Status, AgentEventUnion, SM.Error, SM.GracefulDisconnect],
|
||||
Field(discriminator="type")
|
||||
]
|
||||
"""
|
||||
Объединяет все типы исходящих сообщений в одно для удобной автоматической десериализации.\n
|
||||
Pydantic сам определит нужный тип в зависимости от поля ``type``.\n
|
||||
Использование:\n
|
||||
msg = ServerMessage.model_validate_json(json)
|
||||
"""
|
||||
72
lambda_agent_api/server.py
Normal file
72
lambda_agent_api/server.py
Normal file
|
|
@ -0,0 +1,72 @@
|
|||
from pydantic import BaseModel, Field, TypeAdapter
|
||||
from enum import Enum
|
||||
from typing import Literal, Annotated, Union
|
||||
|
||||
|
||||
__all__ = ['EServerMessage', 'MsgStatus', 'MsgError', 'MsgEventTextChunk', 'MsgEventEnd', 'AgentEventUnion', 'ServerMessage']
|
||||
|
||||
|
||||
class EServerMessage(str, Enum):
|
||||
STATUS = "STATUS"
|
||||
ERROR = "ERROR"
|
||||
GRACEFUL_DISCONNECT = "GRACEFUL_DISCONNECT"
|
||||
AGENT_EVENT_TEXT_CHUNK = "AGENT_EVENT_TEXT_CHUNK"
|
||||
AGENT_EVENT_END = "AGENT_EVENT_END"
|
||||
|
||||
|
||||
class MsgStatus(BaseModel):
|
||||
"""
|
||||
Отправляется сервером при открытии соединения с клиентом.
|
||||
Будет дополнен информацией о готовности агента принимать сообщения.
|
||||
"""
|
||||
type: Literal[EServerMessage.STATUS] = EServerMessage.STATUS
|
||||
|
||||
|
||||
class MsgEventTextChunk(BaseModel):
|
||||
"""
|
||||
Чанк текста ответа агента.
|
||||
"""
|
||||
type: Literal[EServerMessage.AGENT_EVENT_TEXT_CHUNK] = EServerMessage.AGENT_EVENT_TEXT_CHUNK
|
||||
text: str
|
||||
|
||||
|
||||
class MsgEventEnd(BaseModel):
|
||||
"""
|
||||
Агент закончил генерацию ответа.
|
||||
"""
|
||||
type: Literal[EServerMessage.AGENT_EVENT_END] = EServerMessage.AGENT_EVENT_END
|
||||
tokens_used: int
|
||||
|
||||
|
||||
class MsgError(BaseModel):
|
||||
"""
|
||||
Неопределенная ошибка в работе агента.
|
||||
"""
|
||||
type: Literal[EServerMessage.ERROR] = EServerMessage.ERROR
|
||||
code: str
|
||||
details: str
|
||||
|
||||
|
||||
class MsgGracefulDisconnect(BaseModel):
|
||||
"""
|
||||
Отправляется перед завершением работы контейнера с агентом. Например, при долгом бездействии.
|
||||
Нужно, чтобы отделять обрыв соединения из-за ошибки с необходимостью повторного подключения.
|
||||
Приход этого сообщения означает, что агент осознанно завершает работу с клиентом по какой-то причине.
|
||||
Для дальнейшего взаимодействия нужно снова обратиться к мастеру.
|
||||
"""
|
||||
type: Literal[EServerMessage.GRACEFUL_DISCONNECT] = EServerMessage.GRACEFUL_DISCONNECT
|
||||
|
||||
|
||||
AgentEventUnion = Union[MsgEventTextChunk, MsgEventEnd]
|
||||
|
||||
|
||||
ServerMessage = TypeAdapter(Annotated[
|
||||
Union[MsgStatus, MsgEventTextChunk, MsgEventEnd, MsgError, MsgGracefulDisconnect],
|
||||
Field(discriminator="type")
|
||||
])
|
||||
"""
|
||||
Объединяет все типы исходящих сообщений в одно для удобной автоматической десериализации.\n
|
||||
Pydantic сам определит нужный тип в зависимости от поля ``type``.\n
|
||||
Использование:\n
|
||||
msg = ServerMessage.model_validate_json(json)
|
||||
"""
|
||||
Loading…
Add table
Add a link
Reference in a new issue