fix: use direct agent api per request
This commit is contained in:
parent
7d270d3d31
commit
7d58dd1caf
14 changed files with 285 additions and 400 deletions
83
sdk/real.py
83
sdk/real.py
|
|
@ -4,9 +4,6 @@ import asyncio
|
|||
from collections.abc import AsyncIterator
|
||||
from pathlib import Path
|
||||
|
||||
from lambda_agent_api.server import MsgEventSendFile, MsgEventTextChunk
|
||||
|
||||
from sdk.agent_api_wrapper import AgentApiWrapper
|
||||
from sdk.interface import (
|
||||
Attachment,
|
||||
MessageChunk,
|
||||
|
|
@ -17,37 +14,32 @@ from sdk.interface import (
|
|||
UserSettings,
|
||||
)
|
||||
from sdk.prototype_state import PrototypeStateStore
|
||||
from sdk.upstream_agent_api import AgentApi, MsgEventSendFile, MsgEventTextChunk
|
||||
|
||||
|
||||
class RealPlatformClient(PlatformClient):
|
||||
def __init__(
|
||||
self,
|
||||
agent_api: AgentApiWrapper,
|
||||
agent_id: str,
|
||||
agent_base_url: str,
|
||||
prototype_state: PrototypeStateStore,
|
||||
platform: str = "matrix",
|
||||
agent_api_cls=AgentApi,
|
||||
) -> None:
|
||||
self._agent_api = agent_api
|
||||
self._agent_id = agent_id
|
||||
self._agent_base_url = agent_base_url
|
||||
self._agent_api_cls = agent_api_cls
|
||||
self._prototype_state = prototype_state
|
||||
self._platform = platform
|
||||
self._chat_apis: dict[str, AgentApiWrapper] = {}
|
||||
self._chat_api_lock = asyncio.Lock()
|
||||
self._chat_send_locks: dict[str, asyncio.Lock] = {}
|
||||
|
||||
@property
|
||||
def agent_api(self) -> AgentApiWrapper:
|
||||
return self._agent_api
|
||||
def agent_id(self) -> str:
|
||||
return self._agent_id
|
||||
|
||||
async def _get_chat_api(self, chat_id: str):
|
||||
chat_key = str(chat_id)
|
||||
chat_api = self._chat_apis.get(chat_key)
|
||||
if chat_api is None:
|
||||
async with self._chat_api_lock:
|
||||
chat_api = self._chat_apis.get(chat_key)
|
||||
if chat_api is None:
|
||||
chat_api = self._agent_api.for_chat(chat_key)
|
||||
await chat_api.connect()
|
||||
self._chat_apis[chat_key] = chat_api
|
||||
return chat_api
|
||||
@property
|
||||
def agent_base_url(self) -> str:
|
||||
return self._agent_base_url
|
||||
|
||||
def _get_chat_send_lock(self, chat_id: str) -> asyncio.Lock:
|
||||
chat_key = str(chat_id)
|
||||
|
|
@ -82,9 +74,9 @@ class RealPlatformClient(PlatformClient):
|
|||
|
||||
lock = self._get_chat_send_lock(chat_id)
|
||||
async with lock:
|
||||
chat_api = await self._get_chat_api(chat_id)
|
||||
|
||||
chat_api = self._build_chat_api(chat_id)
|
||||
try:
|
||||
await chat_api.connect()
|
||||
async for event in self._stream_agent_events(
|
||||
chat_api, text, attachments=attachments
|
||||
):
|
||||
|
|
@ -96,8 +88,9 @@ class RealPlatformClient(PlatformClient):
|
|||
if attachment is not None:
|
||||
sent_attachments.append(attachment)
|
||||
except Exception as exc:
|
||||
await self._handle_chat_api_failure(chat_id, exc)
|
||||
|
||||
raise self._to_platform_error(exc) from exc
|
||||
finally:
|
||||
await self._close_chat_api(chat_api)
|
||||
await self._prototype_state.set_last_tokens_used(str(chat_id), 0)
|
||||
|
||||
response_kwargs = {
|
||||
|
|
@ -118,8 +111,9 @@ class RealPlatformClient(PlatformClient):
|
|||
) -> AsyncIterator[MessageChunk]:
|
||||
lock = self._get_chat_send_lock(chat_id)
|
||||
async with lock:
|
||||
chat_api = await self._get_chat_api(chat_id)
|
||||
chat_api = self._build_chat_api(chat_id)
|
||||
try:
|
||||
await chat_api.connect()
|
||||
async for event in self._stream_agent_events(
|
||||
chat_api, text, attachments=attachments
|
||||
):
|
||||
|
|
@ -132,7 +126,9 @@ class RealPlatformClient(PlatformClient):
|
|||
elif isinstance(event, MsgEventSendFile):
|
||||
continue
|
||||
except Exception as exc:
|
||||
await self._handle_chat_api_failure(chat_id, exc)
|
||||
raise self._to_platform_error(exc) from exc
|
||||
finally:
|
||||
await self._close_chat_api(chat_api)
|
||||
await self._prototype_state.set_last_tokens_used(str(chat_id), 0)
|
||||
yield MessageChunk(
|
||||
message_id=user_id,
|
||||
|
|
@ -148,20 +144,9 @@ class RealPlatformClient(PlatformClient):
|
|||
await self._prototype_state.update_settings(user_id, action)
|
||||
|
||||
async def disconnect_chat(self, chat_id: str) -> None:
|
||||
chat_key = str(chat_id)
|
||||
chat_api = self._chat_apis.pop(chat_key, None)
|
||||
self._chat_send_locks.pop(chat_key, None)
|
||||
if chat_api is not None:
|
||||
close = getattr(chat_api, "close", None)
|
||||
if callable(close):
|
||||
await close()
|
||||
self._chat_send_locks.pop(str(chat_id), None)
|
||||
|
||||
async def close(self) -> None:
|
||||
for chat_api in list(self._chat_apis.values()):
|
||||
close = getattr(chat_api, "close", None)
|
||||
if callable(close):
|
||||
await close()
|
||||
self._chat_apis.clear()
|
||||
self._chat_send_locks.clear()
|
||||
|
||||
async def _stream_agent_events(
|
||||
|
|
@ -175,10 +160,26 @@ class RealPlatformClient(PlatformClient):
|
|||
async for event in event_stream:
|
||||
yield event
|
||||
|
||||
async def _handle_chat_api_failure(self, chat_id: str, exc: Exception) -> None:
|
||||
await self.disconnect_chat(chat_id)
|
||||
def _build_chat_api(self, chat_id: str):
|
||||
return self._agent_api_cls(
|
||||
agent_id=self._agent_id,
|
||||
base_url=self._agent_base_url,
|
||||
chat_id=str(chat_id),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def _close_chat_api(chat_api) -> None:
|
||||
close = getattr(chat_api, "close", None)
|
||||
if callable(close):
|
||||
try:
|
||||
await close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def _to_platform_error(exc: Exception) -> PlatformError:
|
||||
code = getattr(exc, "code", None) or "PLATFORM_CONNECTION_ERROR"
|
||||
raise PlatformError(str(exc), code=code) from exc
|
||||
return PlatformError(str(exc), code=code)
|
||||
|
||||
@staticmethod
|
||||
def _attachment_paths(attachments: list[Attachment] | None) -> list[str]:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue