diff --git a/tests/manual.py b/tests/manual.py index d59754c..a39802b 100644 --- a/tests/manual.py +++ b/tests/manual.py @@ -2,24 +2,33 @@ import asyncio import traceback from lambda_agent_api.agent_api import AgentApi -from lambda_agent_api.server import MsgEventTextChunk +from lambda_agent_api.server import MsgEventTextChunk, MsgEventToolCallChunk, MsgEventToolResult -def my_callback(message): - print(f"Callback: {message}") - async def main(): - api = AgentApi("agent-1", "ws://localhost:8000/agent_ws/", callback=my_callback) + api = AgentApi("agent-1", "ws://localhost:8000/agent_ws/") await api.connect() while True: try: prompt = await asyncio.get_event_loop().run_in_executor(None, input, ">>> ") print("Agent: ", end="") + is_tool = False async for chunk in api.send_message(prompt): - if isinstance(chunk, MsgEventTextChunk): - print(chunk.text, end="", flush=True) + match chunk: + case MsgEventTextChunk(): + is_tool = False + print(chunk.text, end="", flush=True) + case MsgEventToolCallChunk(): + if not is_tool: + print(f"\n\n### TOOL CALL: ({chunk.tool_name}) ", end="", flush=True) + is_tool = True + print(chunk.args_chunk, end="", flush=True) + case MsgEventToolResult(): + is_tool = False + print(f"\nResult: {chunk.result}\n\n", end="", flush=True) + print("\n") except KeyboardInterrupt: break