fix: use direct agent api per request
This commit is contained in:
parent
7d270d3d31
commit
7d58dd1caf
14 changed files with 285 additions and 400 deletions
|
|
@ -1,48 +0,0 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from urllib.parse import urlsplit, urlunsplit
|
||||
|
||||
_api_root = Path(__file__).resolve().parents[1] / "external" / "platform-agent_api"
|
||||
if str(_api_root) not in sys.path:
|
||||
sys.path.insert(0, str(_api_root))
|
||||
|
||||
from lambda_agent_api.agent_api import AgentApi # noqa: E402
|
||||
|
||||
|
||||
class AgentApiWrapper(AgentApi):
|
||||
"""Thin construction/factory shim over the pinned upstream AgentApi."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
agent_id: str,
|
||||
base_url: str,
|
||||
*,
|
||||
chat_id: int | str = 0,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
self._base_url = self._normalize_base_url(base_url)
|
||||
self._init_kwargs = dict(kwargs)
|
||||
self.chat_id = chat_id
|
||||
super().__init__(
|
||||
agent_id=agent_id,
|
||||
base_url=self._base_url,
|
||||
chat_id=chat_id,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _normalize_base_url(base_url: str) -> str:
|
||||
parsed = urlsplit(base_url)
|
||||
path = re.sub(r"(?:/v1)?/agent_ws(?:/[^/]+)?/?$", "", parsed.path.rstrip("/"))
|
||||
return urlunsplit((parsed.scheme, parsed.netloc, path, "", ""))
|
||||
|
||||
def for_chat(self, chat_id: int | str) -> AgentApiWrapper:
|
||||
return type(self)(
|
||||
agent_id=self.id,
|
||||
base_url=self._base_url,
|
||||
chat_id=chat_id,
|
||||
**self._init_kwargs,
|
||||
)
|
||||
|
|
@ -1 +1 @@
|
|||
"""Compatibility stub: AgentSessionClient was replaced by AgentApiWrapper in Phase 4."""
|
||||
"""Compatibility stub: AgentSessionClient was replaced by direct AgentApi usage in Phase 4."""
|
||||
|
|
|
|||
83
sdk/real.py
83
sdk/real.py
|
|
@ -4,9 +4,6 @@ import asyncio
|
|||
from collections.abc import AsyncIterator
|
||||
from pathlib import Path
|
||||
|
||||
from lambda_agent_api.server import MsgEventSendFile, MsgEventTextChunk
|
||||
|
||||
from sdk.agent_api_wrapper import AgentApiWrapper
|
||||
from sdk.interface import (
|
||||
Attachment,
|
||||
MessageChunk,
|
||||
|
|
@ -17,37 +14,32 @@ from sdk.interface import (
|
|||
UserSettings,
|
||||
)
|
||||
from sdk.prototype_state import PrototypeStateStore
|
||||
from sdk.upstream_agent_api import AgentApi, MsgEventSendFile, MsgEventTextChunk
|
||||
|
||||
|
||||
class RealPlatformClient(PlatformClient):
|
||||
def __init__(
|
||||
self,
|
||||
agent_api: AgentApiWrapper,
|
||||
agent_id: str,
|
||||
agent_base_url: str,
|
||||
prototype_state: PrototypeStateStore,
|
||||
platform: str = "matrix",
|
||||
agent_api_cls=AgentApi,
|
||||
) -> None:
|
||||
self._agent_api = agent_api
|
||||
self._agent_id = agent_id
|
||||
self._agent_base_url = agent_base_url
|
||||
self._agent_api_cls = agent_api_cls
|
||||
self._prototype_state = prototype_state
|
||||
self._platform = platform
|
||||
self._chat_apis: dict[str, AgentApiWrapper] = {}
|
||||
self._chat_api_lock = asyncio.Lock()
|
||||
self._chat_send_locks: dict[str, asyncio.Lock] = {}
|
||||
|
||||
@property
|
||||
def agent_api(self) -> AgentApiWrapper:
|
||||
return self._agent_api
|
||||
def agent_id(self) -> str:
|
||||
return self._agent_id
|
||||
|
||||
async def _get_chat_api(self, chat_id: str):
|
||||
chat_key = str(chat_id)
|
||||
chat_api = self._chat_apis.get(chat_key)
|
||||
if chat_api is None:
|
||||
async with self._chat_api_lock:
|
||||
chat_api = self._chat_apis.get(chat_key)
|
||||
if chat_api is None:
|
||||
chat_api = self._agent_api.for_chat(chat_key)
|
||||
await chat_api.connect()
|
||||
self._chat_apis[chat_key] = chat_api
|
||||
return chat_api
|
||||
@property
|
||||
def agent_base_url(self) -> str:
|
||||
return self._agent_base_url
|
||||
|
||||
def _get_chat_send_lock(self, chat_id: str) -> asyncio.Lock:
|
||||
chat_key = str(chat_id)
|
||||
|
|
@ -82,9 +74,9 @@ class RealPlatformClient(PlatformClient):
|
|||
|
||||
lock = self._get_chat_send_lock(chat_id)
|
||||
async with lock:
|
||||
chat_api = await self._get_chat_api(chat_id)
|
||||
|
||||
chat_api = self._build_chat_api(chat_id)
|
||||
try:
|
||||
await chat_api.connect()
|
||||
async for event in self._stream_agent_events(
|
||||
chat_api, text, attachments=attachments
|
||||
):
|
||||
|
|
@ -96,8 +88,9 @@ class RealPlatformClient(PlatformClient):
|
|||
if attachment is not None:
|
||||
sent_attachments.append(attachment)
|
||||
except Exception as exc:
|
||||
await self._handle_chat_api_failure(chat_id, exc)
|
||||
|
||||
raise self._to_platform_error(exc) from exc
|
||||
finally:
|
||||
await self._close_chat_api(chat_api)
|
||||
await self._prototype_state.set_last_tokens_used(str(chat_id), 0)
|
||||
|
||||
response_kwargs = {
|
||||
|
|
@ -118,8 +111,9 @@ class RealPlatformClient(PlatformClient):
|
|||
) -> AsyncIterator[MessageChunk]:
|
||||
lock = self._get_chat_send_lock(chat_id)
|
||||
async with lock:
|
||||
chat_api = await self._get_chat_api(chat_id)
|
||||
chat_api = self._build_chat_api(chat_id)
|
||||
try:
|
||||
await chat_api.connect()
|
||||
async for event in self._stream_agent_events(
|
||||
chat_api, text, attachments=attachments
|
||||
):
|
||||
|
|
@ -132,7 +126,9 @@ class RealPlatformClient(PlatformClient):
|
|||
elif isinstance(event, MsgEventSendFile):
|
||||
continue
|
||||
except Exception as exc:
|
||||
await self._handle_chat_api_failure(chat_id, exc)
|
||||
raise self._to_platform_error(exc) from exc
|
||||
finally:
|
||||
await self._close_chat_api(chat_api)
|
||||
await self._prototype_state.set_last_tokens_used(str(chat_id), 0)
|
||||
yield MessageChunk(
|
||||
message_id=user_id,
|
||||
|
|
@ -148,20 +144,9 @@ class RealPlatformClient(PlatformClient):
|
|||
await self._prototype_state.update_settings(user_id, action)
|
||||
|
||||
async def disconnect_chat(self, chat_id: str) -> None:
|
||||
chat_key = str(chat_id)
|
||||
chat_api = self._chat_apis.pop(chat_key, None)
|
||||
self._chat_send_locks.pop(chat_key, None)
|
||||
if chat_api is not None:
|
||||
close = getattr(chat_api, "close", None)
|
||||
if callable(close):
|
||||
await close()
|
||||
self._chat_send_locks.pop(str(chat_id), None)
|
||||
|
||||
async def close(self) -> None:
|
||||
for chat_api in list(self._chat_apis.values()):
|
||||
close = getattr(chat_api, "close", None)
|
||||
if callable(close):
|
||||
await close()
|
||||
self._chat_apis.clear()
|
||||
self._chat_send_locks.clear()
|
||||
|
||||
async def _stream_agent_events(
|
||||
|
|
@ -175,10 +160,26 @@ class RealPlatformClient(PlatformClient):
|
|||
async for event in event_stream:
|
||||
yield event
|
||||
|
||||
async def _handle_chat_api_failure(self, chat_id: str, exc: Exception) -> None:
|
||||
await self.disconnect_chat(chat_id)
|
||||
def _build_chat_api(self, chat_id: str):
|
||||
return self._agent_api_cls(
|
||||
agent_id=self._agent_id,
|
||||
base_url=self._agent_base_url,
|
||||
chat_id=str(chat_id),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def _close_chat_api(chat_api) -> None:
|
||||
close = getattr(chat_api, "close", None)
|
||||
if callable(close):
|
||||
try:
|
||||
await close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def _to_platform_error(exc: Exception) -> PlatformError:
|
||||
code = getattr(exc, "code", None) or "PLATFORM_CONNECTION_ERROR"
|
||||
raise PlatformError(str(exc), code=code) from exc
|
||||
return PlatformError(str(exc), code=code)
|
||||
|
||||
@staticmethod
|
||||
def _attachment_paths(attachments: list[Attachment] | None) -> list[str]:
|
||||
|
|
|
|||
19
sdk/upstream_agent_api.py
Normal file
19
sdk/upstream_agent_api.py
Normal file
|
|
@ -0,0 +1,19 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
_api_root = Path(__file__).resolve().parents[1] / "external" / "platform-agent_api"
|
||||
if str(_api_root) not in sys.path:
|
||||
sys.path.insert(0, str(_api_root))
|
||||
|
||||
from lambda_agent_api.agent_api import AgentApi, AgentBusyException, AgentException # noqa: E402
|
||||
from lambda_agent_api.server import MsgEventSendFile, MsgEventTextChunk # noqa: E402
|
||||
|
||||
__all__ = [
|
||||
"AgentApi",
|
||||
"AgentBusyException",
|
||||
"AgentException",
|
||||
"MsgEventSendFile",
|
||||
"MsgEventTextChunk",
|
||||
]
|
||||
Loading…
Add table
Add a link
Reference in a new issue