diff --git a/src/agent/__init__.py b/src/agent/__init__.py index ff3ec8c..ccfc9bd 100644 --- a/src/agent/__init__.py +++ b/src/agent/__init__.py @@ -1,3 +1,3 @@ -from src.agent.service import AgentService, get_agent_service +from src.agent.service import AgentService, AgentChat -__all__ = ["AgentService", "get_agent_service"] +__all__ = ["AgentService", "AgentChat"] diff --git a/src/agent/service.py b/src/agent/service.py index 4cdae6d..aa9b45f 100644 --- a/src/agent/service.py +++ b/src/agent/service.py @@ -1,4 +1,5 @@ -from typing import AsyncIterator +from typing import AsyncIterator, AsyncContextManager, Self +from abc import ABC, abstractmethod from src.agent.base import create_agent from lambda_agent_api.server import ( @@ -6,6 +7,19 @@ from lambda_agent_api.server import ( MsgEventToolResult, MsgEventEnd ) + +class ChatInUseError(Exception): + pass + + +class AgentChat(AsyncContextManager[Self]): + chat_id: int + + @abstractmethod + def astream(self, text: str) -> AsyncIterator[AgentEventUnion]: + ... + + class AgentService: _instance = None @@ -15,7 +29,38 @@ class AgentService: cls._instance._agent = create_agent() return cls._instance - async def astream(self, chat_id: int, text: str) -> AsyncIterator[AgentEventUnion]: + class __AgentChat(AgentChat): + __locks: set[int] = set() + + def __init__(self, service: AgentService, chat_id: int) -> None: + self.__chat_id = chat_id + self.__service = service + + # noinspection PyProtocol + @property + def chat_id(self) -> int: + return self.__chat_id + + async def __aenter__(self): + if self.__chat_id in self.__locks: + raise ChatInUseError() + + self.__locks.add(self.__chat_id) + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + self.__locks.remove(self.__chat_id) + + async def astream(self, text: str) -> AsyncIterator[AgentEventUnion]: + if not self.__chat_id in self.__locks: + raise RuntimeError("Chat must be used in `with` statement") + + return self.__service._AgentService__astream(self.__chat_id, text) + + def chat(self, chat_id: int) -> AgentChat: + return self.__AgentChat(self, chat_id) + + async def __astream(self, chat_id: int, text: str) -> AsyncIterator[AgentEventUnion]: config = {"configurable": {"thread_id": chat_id}} # Используем astream_events для перехвата детальных событий (инструменты, чанки и т.д.) @@ -51,7 +96,3 @@ class AgentService: # 3. В конце генерации отправляем событие завершения yield MsgEventEnd(tokens_used=0) # потом заменить на метадату - - -def get_agent_service() -> AgentService: - return AgentService() diff --git a/src/api/dependencies.py b/src/api/dependencies.py new file mode 100644 index 0000000..01f2cf1 --- /dev/null +++ b/src/api/dependencies.py @@ -0,0 +1,14 @@ +from typing import Annotated, AsyncGenerator +from fastapi import Depends + +from src.agent import AgentService, AgentChat + + +def get_agent_service() -> AgentService: + return AgentService() + + +async def get_chat(service: Annotated[AgentService, Depends(get_agent_service)], + chat_id: int) -> AsyncGenerator[AgentChat]: + async with service.chat(chat_id) as chat: + yield chat diff --git a/src/api/external.py b/src/api/external.py index 9a84391..4985dee 100644 --- a/src/api/external.py +++ b/src/api/external.py @@ -1,3 +1,5 @@ +from typing import Annotated + from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Depends from lambda_agent_api.server import ( @@ -8,7 +10,8 @@ from lambda_agent_api.server import ( ) from lambda_agent_api.client import ClientMessage, MsgUserMessage -from src.agent import get_agent_service, AgentService +from src.agent import AgentChat +from src.api.dependencies import get_chat router = APIRouter() @@ -17,8 +20,7 @@ router = APIRouter() @router.websocket("/agent_ws/{chat_id}/") async def websocket_endpoint( ws: WebSocket, - chat_id: int, - agent_service: AgentService = Depends(get_agent_service), + chat: Annotated[AgentChat, Depends(get_chat)], ): await ws.accept() await ws.send_text(MsgStatus().model_dump_json()) @@ -27,7 +29,7 @@ async def websocket_endpoint( while True: raw = await ws.receive_text() msg = ClientMessage.validate_json(raw) - await process_message(ws, chat_id, msg, agent_service) + await process_message(ws, chat, msg) except WebSocketDisconnect: pass @@ -37,9 +39,9 @@ async def websocket_endpoint( ) -async def process_message(ws: WebSocket, chat_id: int, msg, agent_service: AgentService): +async def process_message(ws: WebSocket, chat: AgentChat, msg): match msg: case MsgUserMessage(): - async for chunk in agent_service.astream(chat_id, msg.text): + async for chunk in chat.astream(msg.text): await ws.send_text(chunk.model_dump_json()) await ws.send_text(MsgEventEnd(tokens_used=0).model_dump_json()) diff --git a/src/main.py b/src/main.py index c11618a..02418a8 100644 --- a/src/main.py +++ b/src/main.py @@ -2,7 +2,7 @@ from contextlib import asynccontextmanager from fastapi import FastAPI -from src.agent import get_agent_service +from src.api.dependencies import get_agent_service from src.api.external import router as ws_router