сохранение состояние между вызовами

This commit is contained in:
Егор Кандрушин 2026-04-03 15:57:32 +03:00
parent 635986ba70
commit d4d775305f
5 changed files with 63 additions and 36 deletions

View file

@ -0,0 +1,3 @@
from src.agent.service import AgentService, get_agent_service
__all__ = ["AgentService", "get_agent_service"]

View file

@ -1,26 +1,19 @@
import os import os
from deepagents import create_deep_agent from deepagents import create_deep_agent
from langchain_openai import ChatOpenAI from langchain_openai import ChatOpenAI
from langgraph.checkpoint.memory import MemorySaver
def create_agent(): def create_agent():
api_url = os.environ.get("PROVIDER_URL")
api_key = os.environ.get("PROVIDER_API_KEY")
model_name = os.environ.get("PROVIDER_MODEL")
if None in (api_url, api_key, model_name):
raise RuntimeError("PROVIDER_URL, PROVIDER_API_KEY or PROVIDER_MODEL_NAME is not configured")
model = ChatOpenAI( model = ChatOpenAI(
model=model_name, model=os.environ["PROVIDER_MODEL"],
base_url=api_url, base_url=os.environ["PROVIDER_URL"],
api_key=api_key, api_key=os.environ["PROVIDER_API_KEY"],
) )
return create_deep_agent( return create_deep_agent(
model=model, model=model,
system_prompt="You are a helpful assistant.", system_prompt="You are a helpful assistant.",
checkpointer=MemorySaver(),
) )
agent = create_agent()

34
src/agent/service.py Normal file
View file

@ -0,0 +1,34 @@
from typing import AsyncIterator
from src.agent.base import create_agent
class AgentService:
_instance = None
def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance._agent = create_agent()
cls._instance._thread_id = "default"
return cls._instance
async def astream(self, text: str) -> AsyncIterator[str]:
config = {"configurable": {"thread_id": self._thread_id}}
async for event in self._agent.astream(
{"messages": [{"role": "user", "content": text}]},
config=config,
):
messages = event.get("messages") or event.get("model", {}).get(
"messages", []
)
if messages:
last_msg = messages[-1]
content = getattr(last_msg, "content", None)
if isinstance(content, str) and content.strip():
yield content
def get_agent_service() -> AgentService:
return AgentService()

View file

@ -1,4 +1,4 @@
from fastapi import APIRouter, WebSocket, WebSocketDisconnect from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Depends
from lambda_agent_api.server import ( from lambda_agent_api.server import (
MsgStatus, MsgStatus,
@ -8,14 +8,17 @@ from lambda_agent_api.server import (
) )
from lambda_agent_api.client import ClientMessage, MsgUserMessage from lambda_agent_api.client import ClientMessage, MsgUserMessage
from src.agent.base import agent from src.agent import get_agent_service, AgentService
router = APIRouter() router = APIRouter()
@router.websocket("/agent_ws/") @router.websocket("/agent_ws/")
async def websocket_endpoint(ws: WebSocket): async def websocket_endpoint(
ws: WebSocket,
agent_service: AgentService = Depends(get_agent_service),
):
await ws.accept() await ws.accept()
await ws.send_text(MsgStatus().model_dump_json()) await ws.send_text(MsgStatus().model_dump_json())
@ -23,7 +26,7 @@ async def websocket_endpoint(ws: WebSocket):
while True: while True:
raw = await ws.receive_text() raw = await ws.receive_text()
msg = ClientMessage.validate_json(raw) msg = ClientMessage.validate_json(raw)
await process_message(ws, msg) await process_message(ws, msg, agent_service)
except WebSocketDisconnect: except WebSocketDisconnect:
pass pass
@ -33,23 +36,9 @@ async def websocket_endpoint(ws: WebSocket):
) )
async def process_message( async def process_message(ws: WebSocket, msg, agent_service: AgentService):
ws: WebSocket, msg
): # msg должно быть ClientMessage (аннотация не работает из-за TypeAdapter)
match msg: match msg:
case MsgUserMessage(): case MsgUserMessage():
await handle_user_message(ws, msg.text) async for chunk in agent_service.astream(msg.text):
await ws.send_text(MsgEventTextChunk(text=chunk).model_dump_json())
await ws.send_text(MsgEventEnd(tokens_used=0).model_dump_json())
async def handle_user_message(ws: WebSocket, text: str):
tokens_used = 0
async for event in agent.astream({"messages": [{"role": "user", "content": text}]}):
messages = event.get("messages") or event.get("model", {}).get("messages", [])
if messages:
last_msg = messages[-1]
content = getattr(last_msg, "content", None)
if isinstance(content, str) and content.strip():
await ws.send_text(MsgEventTextChunk(text=content.strip()).model_dump_json())
await ws.send_text(MsgEventEnd(tokens_used=tokens_used).model_dump_json())

View file

@ -1,8 +1,16 @@
from contextlib import asynccontextmanager
from fastapi import FastAPI from fastapi import FastAPI
from src.agent import get_agent_service
from src.api.external import router as ws_router from src.api.external import router as ws_router
app = FastAPI() @asynccontextmanager
app.include_router(ws_router) async def lifespan(app: FastAPI):
get_agent_service()
yield
app = FastAPI(lifespan=lifespan)
app.include_router(ws_router)