корректные pydantic модели для автоматического определения класса по полю type
This commit is contained in:
parent
b34cbaf677
commit
1e256a545b
11 changed files with 294 additions and 148 deletions
|
|
@ -3,7 +3,8 @@ from typing import Callable, Optional, AsyncIterator
|
|||
import aiohttp
|
||||
import asyncio
|
||||
|
||||
from lambda_agent_api.models import CM, SM, ClientMessage, ServerMessage
|
||||
from lambda_agent_api.server import *
|
||||
from lambda_agent_api.client import *
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -52,8 +53,8 @@ class AgentApi:
|
|||
msg = await asyncio.wait_for(self._ws.receive(), timeout=5.0)
|
||||
|
||||
if msg.type == aiohttp.WSMsgType.TEXT:
|
||||
status_msg = ServerMessage.model_validate_json(msg.data)
|
||||
if isinstance(status_msg, SM.Status):
|
||||
status_msg = ServerMessage.validate_json(msg.data)
|
||||
if isinstance(status_msg, MsgStatus):
|
||||
logger.info(f"Agent {self.id} is ready")
|
||||
else:
|
||||
raise AgentException(
|
||||
|
|
@ -130,7 +131,7 @@ class AgentApi:
|
|||
except Exception as e:
|
||||
logger.error(f"Error in on_disconnect: {e}")
|
||||
|
||||
async def send_message(self, text: str) -> AsyncIterator[SM.AgentEvent]:
|
||||
async def send_message(self, text: str) -> AsyncIterator[AgentEventUnion]:
|
||||
"""
|
||||
Нативный асинхронный генератор.
|
||||
Не требует отдельного класса ResponseIterator.
|
||||
|
|
@ -151,8 +152,8 @@ class AgentApi:
|
|||
try:
|
||||
self._current_queue = asyncio.Queue()
|
||||
|
||||
message = CM.UserMessage(
|
||||
type=CM.Type.USER_MESSAGE,
|
||||
message = MsgUserMessage(
|
||||
type=EClientMessage.USER_MESSAGE,
|
||||
text=text
|
||||
)
|
||||
|
||||
|
|
@ -168,7 +169,7 @@ class AgentApi:
|
|||
raise chunk
|
||||
|
||||
# Если конец ответа
|
||||
if isinstance(chunk, SM.EventEnd):
|
||||
if isinstance(chunk, MsgEventEnd):
|
||||
# Если вам нужны tokens_used, можно сохранить их в атрибут self
|
||||
break
|
||||
|
||||
|
|
@ -222,10 +223,10 @@ class AgentApi:
|
|||
async for msg in self._ws:
|
||||
if msg.type == aiohttp.WSMsgType.TEXT:
|
||||
try:
|
||||
outgoing_msg = ServerMessage.model_validate_json(
|
||||
outgoing_msg = ServerMessage.validate_json(
|
||||
msg.data)
|
||||
|
||||
if isinstance(outgoing_msg, SM.AgentEvent):
|
||||
if isinstance(outgoing_msg, MsgEventTextChunk):
|
||||
if self._current_queue:
|
||||
await self._current_queue.put(outgoing_msg)
|
||||
# Если очереди нет (клиент отменил запрос), но токены идут — шлем их в коллбек
|
||||
|
|
@ -235,11 +236,11 @@ class AgentApi:
|
|||
logger.warning(
|
||||
f"[{self.id}] AgentEvent without active request")
|
||||
|
||||
elif isinstance(outgoing_msg, SM.EventEnd):
|
||||
elif isinstance(outgoing_msg, MsgEventEnd):
|
||||
if self._current_queue:
|
||||
await self._current_queue.put(outgoing_msg)
|
||||
|
||||
elif isinstance(outgoing_msg, SM.Error):
|
||||
elif isinstance(outgoing_msg, MsgError):
|
||||
if self.callback:
|
||||
self.callback(outgoing_msg)
|
||||
error = AgentException(
|
||||
|
|
@ -248,7 +249,7 @@ class AgentApi:
|
|||
if self._current_queue:
|
||||
await self._current_queue.put(error)
|
||||
|
||||
elif isinstance(outgoing_msg, SM.GracefulDisconnect):
|
||||
elif isinstance(outgoing_msg, MsgGracefulDisconnect):
|
||||
if self.callback:
|
||||
self.callback(outgoing_msg)
|
||||
logger.info(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue