diff --git a/src/agent/__init__.py b/src/agent/__init__.py index e69de29..ff3ec8c 100644 --- a/src/agent/__init__.py +++ b/src/agent/__init__.py @@ -0,0 +1,3 @@ +from src.agent.service import AgentService, get_agent_service + +__all__ = ["AgentService", "get_agent_service"] diff --git a/src/agent/base.py b/src/agent/base.py index c78f76f..f635ba8 100644 --- a/src/agent/base.py +++ b/src/agent/base.py @@ -1,26 +1,19 @@ import os from deepagents import create_deep_agent from langchain_openai import ChatOpenAI +from langgraph.checkpoint.memory import MemorySaver def create_agent(): - api_url = os.environ.get("PROVIDER_URL") - api_key = os.environ.get("PROVIDER_API_KEY") - model_name = os.environ.get("PROVIDER_MODEL") - - if None in (api_url, api_key, model_name): - raise RuntimeError("PROVIDER_URL, PROVIDER_API_KEY or PROVIDER_MODEL_NAME is not configured") - model = ChatOpenAI( - model=model_name, - base_url=api_url, - api_key=api_key, + model=os.environ["PROVIDER_MODEL"], + base_url=os.environ["PROVIDER_URL"], + api_key=os.environ["PROVIDER_API_KEY"], ) return create_deep_agent( model=model, system_prompt="You are a helpful assistant.", + checkpointer=MemorySaver(), ) - -agent = create_agent() diff --git a/src/agent/service.py b/src/agent/service.py new file mode 100644 index 0000000..e0d2cc3 --- /dev/null +++ b/src/agent/service.py @@ -0,0 +1,34 @@ +from typing import AsyncIterator + +from src.agent.base import create_agent + + +class AgentService: + _instance = None + + def __new__(cls): + if cls._instance is None: + cls._instance = super().__new__(cls) + cls._instance._agent = create_agent() + cls._instance._thread_id = "default" + return cls._instance + + async def astream(self, text: str) -> AsyncIterator[str]: + config = {"configurable": {"thread_id": self._thread_id}} + + async for event in self._agent.astream( + {"messages": [{"role": "user", "content": text}]}, + config=config, + ): + messages = event.get("messages") or event.get("model", {}).get( + "messages", [] + ) + if messages: + last_msg = messages[-1] + content = getattr(last_msg, "content", None) + if isinstance(content, str) and content.strip(): + yield content + + +def get_agent_service() -> AgentService: + return AgentService() diff --git a/src/api/external.py b/src/api/external.py index 8c2579d..219d5f7 100644 --- a/src/api/external.py +++ b/src/api/external.py @@ -1,4 +1,4 @@ -from fastapi import APIRouter, WebSocket, WebSocketDisconnect +from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Depends from lambda_agent_api.server import ( MsgStatus, @@ -8,14 +8,17 @@ from lambda_agent_api.server import ( ) from lambda_agent_api.client import ClientMessage, MsgUserMessage -from src.agent.base import agent +from src.agent import get_agent_service, AgentService router = APIRouter() @router.websocket("/agent_ws/") -async def websocket_endpoint(ws: WebSocket): +async def websocket_endpoint( + ws: WebSocket, + agent_service: AgentService = Depends(get_agent_service), +): await ws.accept() await ws.send_text(MsgStatus().model_dump_json()) @@ -23,7 +26,7 @@ async def websocket_endpoint(ws: WebSocket): while True: raw = await ws.receive_text() msg = ClientMessage.validate_json(raw) - await process_message(ws, msg) + await process_message(ws, msg, agent_service) except WebSocketDisconnect: pass @@ -33,23 +36,9 @@ async def websocket_endpoint(ws: WebSocket): ) -async def process_message( - ws: WebSocket, msg -): # msg должно быть ClientMessage (аннотация не работает из-за TypeAdapter) +async def process_message(ws: WebSocket, msg, agent_service: AgentService): match msg: case MsgUserMessage(): - await handle_user_message(ws, msg.text) - - -async def handle_user_message(ws: WebSocket, text: str): - tokens_used = 0 - - async for event in agent.astream({"messages": [{"role": "user", "content": text}]}): - messages = event.get("messages") or event.get("model", {}).get("messages", []) - if messages: - last_msg = messages[-1] - content = getattr(last_msg, "content", None) - if isinstance(content, str) and content.strip(): - await ws.send_text(MsgEventTextChunk(text=content.strip()).model_dump_json()) - - await ws.send_text(MsgEventEnd(tokens_used=tokens_used).model_dump_json()) + async for chunk in agent_service.astream(msg.text): + await ws.send_text(MsgEventTextChunk(text=chunk).model_dump_json()) + await ws.send_text(MsgEventEnd(tokens_used=0).model_dump_json()) diff --git a/src/main.py b/src/main.py index ade53dc..c11618a 100644 --- a/src/main.py +++ b/src/main.py @@ -1,8 +1,16 @@ +from contextlib import asynccontextmanager + from fastapi import FastAPI +from src.agent import get_agent_service from src.api.external import router as ws_router -app = FastAPI() -app.include_router(ws_router) +@asynccontextmanager +async def lifespan(app: FastAPI): + get_agent_service() + yield + +app = FastAPI(lifespan=lifespan) +app.include_router(ws_router)