test: restore vllm integration coverage and add dict-args regression
Restore the existing vLLM integration test module that was accidentally replaced during development and add a focused agent-loop regression test for dict tool-call arguments from OpenAI-compatible local backends.
This commit is contained in:
parent
93a0c0cddd
commit
5847c180c6
2 changed files with 413 additions and 46 deletions
|
|
@ -1,64 +1,359 @@
|
||||||
|
"""Integration tests for HermesAgentLoop with a local vLLM server.
|
||||||
|
|
||||||
|
Tests the full Phase 2 flow: ManagedServer + tool calling with a real
|
||||||
|
vLLM backend, producing actual token IDs and logprobs for RL training.
|
||||||
|
|
||||||
|
Requires a running vLLM server. Start one from the atropos directory:
|
||||||
|
|
||||||
|
python -m example_trainer.vllm_api_server \
|
||||||
|
--model Qwen/Qwen3-4B-Thinking-2507 \
|
||||||
|
--port 9001 \
|
||||||
|
--gpu-memory-utilization 0.8 \
|
||||||
|
--max-model-len=32000
|
||||||
|
|
||||||
|
Tests are automatically skipped if the server is not reachable.
|
||||||
|
|
||||||
|
Run:
|
||||||
|
pytest tests/test_agent_loop_vllm.py -v
|
||||||
|
pytest tests/test_agent_loop_vllm.py -v -k "single"
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import json
|
import json
|
||||||
from types import SimpleNamespace
|
import os
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import requests
|
||||||
|
|
||||||
|
# Ensure repo root is importable
|
||||||
|
_repo_root = Path(__file__).resolve().parent.parent
|
||||||
|
if str(_repo_root) not in sys.path:
|
||||||
|
sys.path.insert(0, str(_repo_root))
|
||||||
|
|
||||||
|
# Skip the whole module (rather than error) when the project deps are absent.
try:
    from environments.agent_loop import AgentResult, HermesAgentLoop
except ImportError:
    pytest.skip("atroposlib not installed", allow_module_level=True)
|
||||||
|
|
||||||
|
|
||||||
def _tool_call(name: str, arguments):
|
# =========================================================================
# Configuration
# =========================================================================

# Local vLLM server endpoint and model; must match the server started via
# example_trainer.vllm_api_server (see module docstring).
VLLM_HOST = "localhost"
VLLM_PORT = 9001
VLLM_BASE_URL = f"http://{VLLM_HOST}:{VLLM_PORT}"
VLLM_MODEL = "Qwen/Qwen3-4B-Thinking-2507"
|
||||||
|
|
||||||
|
|
||||||
|
def _vllm_is_running() -> bool:
    """Check if the vLLM server is reachable.

    Returns True only when GET /health answers 200 within 3 seconds;
    any exception (connection refused, timeout, DNS) means "not running".
    """
    try:
        r = requests.get(f"{VLLM_BASE_URL}/health", timeout=3)
        return r.status_code == 200
    except Exception:
        # Broad catch is deliberate: an unreachable server must skip, not fail.
        return False
|
||||||
|
|
||||||
|
|
||||||
|
# Skip all tests in this module if vLLM is not running
pytestmark = pytest.mark.skipif(
    not _vllm_is_running(),
    reason=(
        f"vLLM server not reachable at {VLLM_BASE_URL}. "
        "Start it with: python -m example_trainer.vllm_api_server "
        f"--model {VLLM_MODEL} --port {VLLM_PORT} "
        "--gpu-memory-utilization 0.8 --max-model-len=32000"
    ),
)
|
||||||
|
|
||||||
|
|
||||||
|
# =========================================================================
# Server setup
# =========================================================================


def _make_server_manager():
    """Create a ServerManager pointing to the local vLLM server.

    Health checks are disabled and the single server is force-marked
    healthy, since reachability was already verified at collection time.
    """
    # Imported lazily so module collection does not require atroposlib.
    from atroposlib.envs.server_handling.server_manager import (
        ServerManager,
        APIServerConfig,
    )

    config = APIServerConfig(
        base_url=VLLM_BASE_URL,
        model_name=VLLM_MODEL,
        server_type="vllm",
        health_check=False,
    )

    sm = ServerManager([config], tool_parser="hermes")
    sm.servers[0].server_healthy = True
    return sm
|
||||||
|
|
||||||
|
|
||||||
class _FakeChatCompletions:
|
def _get_tokenizer():
    """Load the tokenizer for the model.

    Imported lazily so that merely collecting this module does not
    require transformers to be installed.
    """
    from transformers import AutoTokenizer

    return AutoTokenizer.from_pretrained(VLLM_MODEL)
|
||||||
|
|
||||||
def create(self, **kwargs):
|
|
||||||
self.calls += 1
|
# =========================================================================
# Fake tools
# =========================================================================

# OpenAI-style function schema advertised to the model; the actual call is
# intercepted and answered by _fake_tool_handler.
WEATHER_TOOL = {
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Get the current weather for a city. Returns temperature and conditions.",
        "parameters": {
            "type": "object",
            "properties": {
                "city": {
                    "type": "string",
                    "description": "City name, e.g. 'Tokyo'",
                }
            },
            "required": ["city"],
        },
    },
}
|
||||||
|
|
||||||
|
# OpenAI-style function schema for a calculator tool; answered by
# _fake_tool_handler, never by real code.
CALC_TOOL = {
    "type": "function",
    "function": {
        "name": "calculate",
        "description": "Calculate a math expression. Returns the numeric result.",
        "parameters": {
            "type": "object",
            "properties": {
                "expression": {
                    "type": "string",
                    "description": "Math expression, e.g. '2 + 3'",
                }
            },
            "required": ["expression"],
        },
    },
}
|
||||||
|
|
||||||
|
|
||||||
|
def _fake_tool_handler(tool_name: str, args: Dict[str, Any], **kwargs) -> str:
|
||||||
|
"""Handle fake tool calls for testing."""
|
||||||
|
if tool_name == "get_weather":
|
||||||
|
city = args.get("city", "Unknown")
|
||||||
|
return json.dumps({
|
||||||
|
"city": city,
|
||||||
|
"temperature": 22,
|
||||||
|
"conditions": "sunny",
|
||||||
|
"humidity": 45,
|
||||||
|
})
|
||||||
|
elif tool_name == "calculate":
|
||||||
|
expr = args.get("expression", "0")
|
||||||
|
try:
|
||||||
|
result = eval(expr, {"__builtins__": {}}, {})
|
||||||
|
return json.dumps({"result": result})
|
||||||
|
except Exception as e:
|
||||||
|
return json.dumps({"error": str(e)})
|
||||||
|
return json.dumps({"error": f"Unknown tool: {tool_name}"})
|
||||||
|
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# Tests
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_vllm_single_tool_call():
    """vLLM model calls a tool, gets result, responds — full Phase 2 flow."""
    sm = _make_server_manager()
    tokenizer = _get_tokenizer()

    async with sm.managed_server(tokenizer=tokenizer) as managed:
        agent = HermesAgentLoop(
            server=managed,
            tool_schemas=[WEATHER_TOOL],
            valid_tool_names={"get_weather"},
            max_turns=5,
            temperature=0.6,
            max_tokens=1000,
        )

        messages = [
            {"role": "user", "content": "What's the weather in Tokyo? Use the get_weather tool."},
        ]

        # Intercept tool execution so no real tool code runs.
        with patch("environments.agent_loop.handle_function_call", side_effect=_fake_tool_handler):
            result = await agent.run(messages)

        assert isinstance(result, AgentResult)
        # One tool-call turn plus the final answer => at least 2 turns.
        assert result.turns_used >= 2, f"Expected at least 2 turns, got {result.turns_used}"

        # Verify tool call happened
        tool_calls_found = False
        for msg in result.messages:
            if msg.get("role") == "assistant" and msg.get("tool_calls"):
                for tc in msg["tool_calls"]:
                    if tc["function"]["name"] == "get_weather":
                        tool_calls_found = True
                        args = json.loads(tc["function"]["arguments"])
                        assert "city" in args
        assert tool_calls_found, "Model should have called get_weather"

        # Verify tool results in conversation
        tool_results = [m for m in result.messages if m.get("role") == "tool"]
        assert len(tool_results) >= 1
|
||||||
|
|
||||||
|
|
||||||
def test_tool_call_validation_accepts_dict_arguments(monkeypatch):
|
@pytest.mark.asyncio
async def test_vllm_multi_tool_calls():
    """vLLM model calls multiple tools across turns."""
    sm = _make_server_manager()
    tokenizer = _get_tokenizer()

    async with sm.managed_server(tokenizer=tokenizer) as managed:
        agent = HermesAgentLoop(
            server=managed,
            tool_schemas=[WEATHER_TOOL, CALC_TOOL],
            valid_tool_names={"get_weather", "calculate"},
            max_turns=10,
            temperature=0.6,
            max_tokens=1000,
        )

        messages = [
            {"role": "user", "content": (
                "I need two things: "
                "1) What's the weather in Paris? Use get_weather. "
                "2) What is 15 * 7? Use calculate."
            )},
        ]

        with patch("environments.agent_loop.handle_function_call", side_effect=_fake_tool_handler):
            result = await agent.run(messages)

        # Both tools should be called
        tools_called = set()
        for msg in result.messages:
            if msg.get("role") == "assistant" and msg.get("tool_calls"):
                for tc in msg["tool_calls"]:
                    tools_called.add(tc["function"]["name"])

        assert "get_weather" in tools_called, f"get_weather not called. Called: {tools_called}"
        assert "calculate" in tools_called, f"calculate not called. Called: {tools_called}"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_vllm_managed_server_produces_nodes():
    """ManagedServer should produce SequenceNodes with tokens and logprobs."""
    sm = _make_server_manager()
    tokenizer = _get_tokenizer()

    async with sm.managed_server(tokenizer=tokenizer) as managed:
        agent = HermesAgentLoop(
            server=managed,
            tool_schemas=[WEATHER_TOOL],
            valid_tool_names={"get_weather"},
            max_turns=5,
            temperature=0.6,
            max_tokens=1000,
        )

        messages = [
            {"role": "user", "content": "What's the weather in Berlin? Use get_weather."},
        ]

        with patch("environments.agent_loop.handle_function_call", side_effect=_fake_tool_handler):
            result = await agent.run(messages)

        # Get the managed state — should have SequenceNodes
        state = managed.get_state()

        assert state is not None, "ManagedServer should return state"
        nodes = state.get("nodes", [])
        assert len(nodes) >= 1, f"Should have at least 1 node, got {len(nodes)}"

        # Each node must carry aligned token/logprob arrays for RL training.
        node = nodes[0]
        assert hasattr(node, "tokens"), "Node should have tokens"
        assert hasattr(node, "logprobs"), "Node should have logprobs"
        assert len(node.tokens) > 0, "Tokens should not be empty"
        assert len(node.logprobs) > 0, "Logprobs should not be empty"
        assert len(node.tokens) == len(node.logprobs), (
            f"Tokens ({len(node.tokens)}) and logprobs ({len(node.logprobs)}) should have same length"
        )
|
||||||
|
|
||||||
agent = AIAgent(
|
|
||||||
model="test-model",
|
@pytest.mark.asyncio
async def test_vllm_no_tools_direct_response():
    """vLLM model should respond directly when no tools are needed."""
    sm = _make_server_manager()
    tokenizer = _get_tokenizer()

    async with sm.managed_server(tokenizer=tokenizer) as managed:
        agent = HermesAgentLoop(
            server=managed,
            tool_schemas=[WEATHER_TOOL],
            valid_tool_names={"get_weather"},
            max_turns=5,
            temperature=0.6,
            max_tokens=500,
        )

        messages = [
            {"role": "user", "content": "What is 2 + 2? Answer directly, no tools."},
        ]

        with patch("environments.agent_loop.handle_function_call", side_effect=_fake_tool_handler):
            result = await agent.run(messages)

        assert result.finished_naturally, "Should finish naturally"
        # No tool round-trip expected, so the loop should end after one turn.
        assert result.turns_used == 1, f"Should take 1 turn, took {result.turns_used}"

        final = result.messages[-1]
        assert final["role"] == "assistant"
        assert final["content"], "Should have content"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_vllm_thinking_content_extracted():
    """Qwen3-Thinking model should produce reasoning content."""
    sm = _make_server_manager()
    tokenizer = _get_tokenizer()

    async with sm.managed_server(
        tokenizer=tokenizer,
        preserve_think_blocks=True,
    ) as managed:
        agent = HermesAgentLoop(
            server=managed,
            tool_schemas=[CALC_TOOL],
            valid_tool_names={"calculate"},
            max_turns=5,
            temperature=0.6,
            max_tokens=1000,
        )

        messages = [
            {"role": "user", "content": "What is 123 * 456? Use the calculate tool."},
        ]

        with patch("environments.agent_loop.handle_function_call", side_effect=_fake_tool_handler):
            result = await agent.run(messages)

        # Qwen3-Thinking should generate <think> blocks
        # Check if any content contains thinking markers
        has_thinking = False
        for msg in result.messages:
            content = msg.get("content", "") or ""
            if "<think>" in content or "</think>" in content:
                has_thinking = True
                break

        # Also check reasoning_per_turn
        has_reasoning = any(r for r in result.reasoning_per_turn if r)

        # At least one of these should be true for a thinking model
        assert has_thinking or has_reasoning, (
            "Qwen3-Thinking should produce <think> blocks or reasoning content"
        )
|
||||||
|
|
||||||
result = agent.run_conversation("read the file")
|
|
||||||
|
|
||||||
assert result["final_response"] == "done"
|
|
||||||
|
|
|
||||||
72
tests/test_dict_tool_call_args.py
Normal file
72
tests/test_dict_tool_call_args.py
Normal file
|
|
@ -0,0 +1,72 @@
|
||||||
|
import json
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
|
||||||
|
def _tool_call(name: str, arguments):
|
||||||
|
return SimpleNamespace(
|
||||||
|
id="call_1",
|
||||||
|
type="function",
|
||||||
|
function=SimpleNamespace(name=name, arguments=arguments),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _response_with_tool_call(arguments):
    """Build a fake chat-completion response containing one read_file call.

    Mirrors the object shape an OpenAI-compatible client returns:
    ``.choices[0].message`` with ``tool_calls`` and ``finish_reason``.
    """
    assistant = SimpleNamespace(
        content=None,
        reasoning=None,
        tool_calls=[_tool_call("read_file", arguments)],
    )
    choice = SimpleNamespace(message=assistant, finish_reason="tool_calls")
    return SimpleNamespace(choices=[choice], usage=None)
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeChatCompletions:
|
||||||
|
def __init__(self):
|
||||||
|
self.calls = 0
|
||||||
|
|
||||||
|
def create(self, **kwargs):
|
||||||
|
self.calls += 1
|
||||||
|
if self.calls == 1:
|
||||||
|
return _response_with_tool_call({"path": "README.md"})
|
||||||
|
return SimpleNamespace(
|
||||||
|
choices=[
|
||||||
|
SimpleNamespace(
|
||||||
|
message=SimpleNamespace(content="done", reasoning=None, tool_calls=[]),
|
||||||
|
finish_reason="stop",
|
||||||
|
)
|
||||||
|
],
|
||||||
|
usage=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeClient:
    """Minimal OpenAI-client double exposing only ``chat.completions``."""

    def __init__(self):
        self.chat = SimpleNamespace(completions=_FakeChatCompletions())
|
||||||
|
|
||||||
|
|
||||||
|
def test_tool_call_validation_accepts_dict_arguments(monkeypatch):
    """Agent loop must accept tool-call arguments that arrive as a dict.

    OpenAI-compatible local backends sometimes return ``function.arguments``
    as an already-parsed dict instead of a JSON string; the conversation
    should still complete normally.
    """
    from run_agent import AIAgent

    # Replace network client and tool plumbing with in-process fakes.
    monkeypatch.setattr("run_agent.OpenAI", lambda **kwargs: _FakeClient())
    monkeypatch.setattr(
        "run_agent.get_tool_definitions",
        lambda *args, **kwargs: [{"function": {"name": "read_file"}}],
    )
    monkeypatch.setattr(
        "run_agent.handle_function_call",
        lambda name, args, task_id=None, **kwargs: json.dumps({"ok": True, "args": args}),
    )

    agent = AIAgent(
        model="test-model",
        api_key="test-key",
        base_url="http://localhost:8080/v1",
        platform="cli",
        max_iterations=3,
        quiet_mode=True,
        skip_memory=True,
    )

    result = agent.run_conversation("read the file")

    # Second scripted response is "done"; reaching it proves the dict-args
    # tool call was validated and executed rather than rejected.
    assert result["final_response"] == "done"
|
||||||
Loading…
Add table
Add a link
Reference in a new issue