diff --git a/tests/manual.py b/tests/manual.py index d59754c..451b8a2 100644 --- a/tests/manual.py +++ b/tests/manual.py @@ -2,7 +2,7 @@ import asyncio import traceback from lambda_agent_api.agent_api import AgentApi -from lambda_agent_api.server import MsgEventTextChunk +from lambda_agent_api.server import MsgEventTextChunk, MsgEventToolCallChunk, MsgEventToolResult def my_callback(message): @@ -17,9 +17,21 @@ async def main(): try: prompt = await asyncio.get_event_loop().run_in_executor(None, input, ">>> ") print("Agent: ", end="") + is_tool = False async for chunk in api.send_message(prompt): - if isinstance(chunk, MsgEventTextChunk): - print(chunk.text, end="", flush=True) + match chunk: + case MsgEventTextChunk(): + is_tool = True + print(chunk.text, end="", flush=True) + case MsgEventToolCallChunk(): + if not is_tool: + print(f"\n\n### TOOL CALL: ({chunk.tool_name}) ", end="", flush=True) + is_tool = True + print(chunk.args_chunk, end="", flush=True) + case MsgEventToolResult(): + is_tool = False + print(f"\nResult: {chunk.result}\n\n", end="", flush=True) + print("\n") except KeyboardInterrupt: break