сохранение состояние между вызовами
This commit is contained in:
parent
635986ba70
commit
d4d775305f
5 changed files with 63 additions and 36 deletions
|
|
@ -0,0 +1,3 @@
|
||||||
|
from src.agent.service import AgentService, get_agent_service
|
||||||
|
|
||||||
|
__all__ = ["AgentService", "get_agent_service"]
|
||||||
|
|
@ -1,26 +1,19 @@
|
||||||
import os
|
import os
|
||||||
from deepagents import create_deep_agent
|
from deepagents import create_deep_agent
|
||||||
from langchain_openai import ChatOpenAI
|
from langchain_openai import ChatOpenAI
|
||||||
|
from langgraph.checkpoint.memory import MemorySaver
|
||||||
|
|
||||||
|
|
||||||
def create_agent():
|
def create_agent():
|
||||||
api_url = os.environ.get("PROVIDER_URL")
|
|
||||||
api_key = os.environ.get("PROVIDER_API_KEY")
|
|
||||||
model_name = os.environ.get("PROVIDER_MODEL")
|
|
||||||
|
|
||||||
if None in (api_url, api_key, model_name):
|
|
||||||
raise RuntimeError("PROVIDER_URL, PROVIDER_API_KEY or PROVIDER_MODEL_NAME is not configured")
|
|
||||||
|
|
||||||
model = ChatOpenAI(
|
model = ChatOpenAI(
|
||||||
model=model_name,
|
model=os.environ["PROVIDER_MODEL"],
|
||||||
base_url=api_url,
|
base_url=os.environ["PROVIDER_URL"],
|
||||||
api_key=api_key,
|
api_key=os.environ["PROVIDER_API_KEY"],
|
||||||
)
|
)
|
||||||
|
|
||||||
return create_deep_agent(
|
return create_deep_agent(
|
||||||
model=model,
|
model=model,
|
||||||
system_prompt="You are a helpful assistant.",
|
system_prompt="You are a helpful assistant.",
|
||||||
|
checkpointer=MemorySaver(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
agent = create_agent()
|
|
||||||
|
|
|
||||||
34
src/agent/service.py
Normal file
34
src/agent/service.py
Normal file
|
|
@ -0,0 +1,34 @@
|
||||||
|
from typing import AsyncIterator
|
||||||
|
|
||||||
|
from src.agent.base import create_agent
|
||||||
|
|
||||||
|
|
||||||
|
class AgentService:
|
||||||
|
_instance = None
|
||||||
|
|
||||||
|
def __new__(cls):
|
||||||
|
if cls._instance is None:
|
||||||
|
cls._instance = super().__new__(cls)
|
||||||
|
cls._instance._agent = create_agent()
|
||||||
|
cls._instance._thread_id = "default"
|
||||||
|
return cls._instance
|
||||||
|
|
||||||
|
async def astream(self, text: str) -> AsyncIterator[str]:
|
||||||
|
config = {"configurable": {"thread_id": self._thread_id}}
|
||||||
|
|
||||||
|
async for event in self._agent.astream(
|
||||||
|
{"messages": [{"role": "user", "content": text}]},
|
||||||
|
config=config,
|
||||||
|
):
|
||||||
|
messages = event.get("messages") or event.get("model", {}).get(
|
||||||
|
"messages", []
|
||||||
|
)
|
||||||
|
if messages:
|
||||||
|
last_msg = messages[-1]
|
||||||
|
content = getattr(last_msg, "content", None)
|
||||||
|
if isinstance(content, str) and content.strip():
|
||||||
|
yield content
|
||||||
|
|
||||||
|
|
||||||
|
def get_agent_service() -> AgentService:
|
||||||
|
return AgentService()
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
|
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Depends
|
||||||
|
|
||||||
from lambda_agent_api.server import (
|
from lambda_agent_api.server import (
|
||||||
MsgStatus,
|
MsgStatus,
|
||||||
|
|
@ -8,14 +8,17 @@ from lambda_agent_api.server import (
|
||||||
)
|
)
|
||||||
from lambda_agent_api.client import ClientMessage, MsgUserMessage
|
from lambda_agent_api.client import ClientMessage, MsgUserMessage
|
||||||
|
|
||||||
from src.agent.base import agent
|
from src.agent import get_agent_service, AgentService
|
||||||
|
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
@router.websocket("/agent_ws/")
|
@router.websocket("/agent_ws/")
|
||||||
async def websocket_endpoint(ws: WebSocket):
|
async def websocket_endpoint(
|
||||||
|
ws: WebSocket,
|
||||||
|
agent_service: AgentService = Depends(get_agent_service),
|
||||||
|
):
|
||||||
await ws.accept()
|
await ws.accept()
|
||||||
await ws.send_text(MsgStatus().model_dump_json())
|
await ws.send_text(MsgStatus().model_dump_json())
|
||||||
|
|
||||||
|
|
@ -23,7 +26,7 @@ async def websocket_endpoint(ws: WebSocket):
|
||||||
while True:
|
while True:
|
||||||
raw = await ws.receive_text()
|
raw = await ws.receive_text()
|
||||||
msg = ClientMessage.validate_json(raw)
|
msg = ClientMessage.validate_json(raw)
|
||||||
await process_message(ws, msg)
|
await process_message(ws, msg, agent_service)
|
||||||
|
|
||||||
except WebSocketDisconnect:
|
except WebSocketDisconnect:
|
||||||
pass
|
pass
|
||||||
|
|
@ -33,23 +36,9 @@ async def websocket_endpoint(ws: WebSocket):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def process_message(
|
async def process_message(ws: WebSocket, msg, agent_service: AgentService):
|
||||||
ws: WebSocket, msg
|
|
||||||
): # msg должно быть ClientMessage (аннотация не работает из-за TypeAdapter)
|
|
||||||
match msg:
|
match msg:
|
||||||
case MsgUserMessage():
|
case MsgUserMessage():
|
||||||
await handle_user_message(ws, msg.text)
|
async for chunk in agent_service.astream(msg.text):
|
||||||
|
await ws.send_text(MsgEventTextChunk(text=chunk).model_dump_json())
|
||||||
|
await ws.send_text(MsgEventEnd(tokens_used=0).model_dump_json())
|
||||||
async def handle_user_message(ws: WebSocket, text: str):
|
|
||||||
tokens_used = 0
|
|
||||||
|
|
||||||
async for event in agent.astream({"messages": [{"role": "user", "content": text}]}):
|
|
||||||
messages = event.get("messages") or event.get("model", {}).get("messages", [])
|
|
||||||
if messages:
|
|
||||||
last_msg = messages[-1]
|
|
||||||
content = getattr(last_msg, "content", None)
|
|
||||||
if isinstance(content, str) and content.strip():
|
|
||||||
await ws.send_text(MsgEventTextChunk(text=content.strip()).model_dump_json())
|
|
||||||
|
|
||||||
await ws.send_text(MsgEventEnd(tokens_used=tokens_used).model_dump_json())
|
|
||||||
|
|
|
||||||
12
src/main.py
12
src/main.py
|
|
@ -1,8 +1,16 @@
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
|
|
||||||
|
from src.agent import get_agent_service
|
||||||
from src.api.external import router as ws_router
|
from src.api.external import router as ws_router
|
||||||
|
|
||||||
|
|
||||||
app = FastAPI()
|
@asynccontextmanager
|
||||||
app.include_router(ws_router)
|
async def lifespan(app: FastAPI):
|
||||||
|
get_agent_service()
|
||||||
|
yield
|
||||||
|
|
||||||
|
|
||||||
|
app = FastAPI(lifespan=lifespan)
|
||||||
|
app.include_router(ws_router)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue