fix: bring WebResearchEnv up to Atropos environment standards
The environment was merged missing several standard components. Updated to match the patterns established by 82 Atropos environments and our own HermesAgentBaseEnv contract. Added: - WebResearchEnvConfig — custom Pydantic config with reward weights, efficiency thresholds, eval settings, dataset config (all tunable via CLI/YAML without code changes) - config_init() classmethod — default server config (OpenRouter + Claude) so the env works out of the box - wandb_log() override — logs reward breakdown metrics (correctness, tool_usage, efficiency, diversity, correct_rate, tool_usage_rate) with proper buffer management and super() call - evaluate() — uses server.chat_completion instead of broken stub _run_agent_on_item(). Logs via evaluate_log() for lighteval- compatible output. Fixed: - Removed broken _run_agent_on_item() stub that returned empty results - evaluate() now uses server.chat_completion (same pattern as TerminalTestEnv) for actual model evaluation - compute_reward reads tool calls from AgentResult properly - LLM judge uses self.server.chat_completion instead of ctx Reward config is now tunable without code changes: --env.correctness_weight 0.6 --env.tool_usage_weight 0.2 --env.efficiency_weight 0.2 --env.diversity_bonus 0.1 --env.efficient_max_calls 5
This commit is contained in:
parent
5212644861
commit
8eabdefa8a
1 changed files with 270 additions and 144 deletions
|
|
@ -16,21 +16,18 @@ Dataset: FRAMES benchmark (Google, 2024) — multi-hop factual questions
|
||||||
|
|
||||||
Usage:
|
Usage:
|
||||||
# Phase 1 (OpenAI-compatible server)
|
# Phase 1 (OpenAI-compatible server)
|
||||||
python environments/web_research_env.py serve \
|
python environments/web_research_env.py serve \\
|
||||||
--openai.base_url http://localhost:8000/v1 \
|
--openai.base_url http://localhost:8000/v1 \\
|
||||||
--openai.model_name YourModel \
|
--openai.model_name YourModel \\
|
||||||
--openai.server_type openai
|
--openai.server_type openai
|
||||||
|
|
||||||
# With eval split
|
# Process mode (offline data generation)
|
||||||
python environments/web_research_env.py serve \
|
python environments/web_research_env.py process \\
|
||||||
--openai.base_url http://localhost:8000/v1 \
|
--env.data_path_to_save_groups data/web_research.jsonl
|
||||||
--openai.model_name YourModel \
|
|
||||||
--env.eval_every 50 \
|
|
||||||
--env.eval_size 20
|
|
||||||
|
|
||||||
# Standalone eval (no training server needed)
|
# Standalone eval
|
||||||
python environments/web_research_env.py eval \
|
python environments/web_research_env.py evaluate \\
|
||||||
--openai.base_url http://localhost:8000/v1 \
|
--openai.base_url http://localhost:8000/v1 \\
|
||||||
--openai.model_name YourModel
|
--openai.model_name YourModel
|
||||||
|
|
||||||
Built by: github.com/jackx707
|
Built by: github.com/jackx707
|
||||||
|
|
@ -43,11 +40,21 @@ from __future__ import annotations
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
import random
|
import random
|
||||||
import re
|
import re
|
||||||
from typing import Any, Optional
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
from pydantic import Field
|
||||||
|
|
||||||
|
# Ensure hermes-agent root is on path
|
||||||
|
_repo_root = Path(__file__).resolve().parent.parent
|
||||||
|
if str(_repo_root) not in sys.path:
|
||||||
|
sys.path.insert(0, str(_repo_root))
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Optional HuggingFace datasets import
|
# Optional HuggingFace datasets import
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
@ -57,13 +64,19 @@ try:
|
||||||
except ImportError:
|
except ImportError:
|
||||||
HF_AVAILABLE = False
|
HF_AVAILABLE = False
|
||||||
|
|
||||||
from environments.hermes_base_env import HermesAgentBaseEnv
|
from atroposlib.envs.base import ScoredDataGroup
|
||||||
|
from atroposlib.envs.server_handling.server_manager import APIServerConfig
|
||||||
|
from atroposlib.type_definitions import Item
|
||||||
|
|
||||||
|
from environments.hermes_base_env import HermesAgentBaseEnv, HermesAgentEnvConfig
|
||||||
|
from environments.agent_loop import AgentResult
|
||||||
|
from environments.tool_context import ToolContext
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Fallback sample dataset (used when HuggingFace is unavailable)
|
# Fallback sample dataset (used when HuggingFace is unavailable)
|
||||||
# These are multi-hop questions that require real web search to answer.
|
# Multi-hop questions requiring real web search to answer.
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
SAMPLE_QUESTIONS = [
|
SAMPLE_QUESTIONS = [
|
||||||
{
|
{
|
||||||
|
|
@ -129,6 +142,58 @@ SAMPLE_QUESTIONS = [
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Configuration
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class WebResearchEnvConfig(HermesAgentEnvConfig):
|
||||||
|
"""Configuration for the web research RL environment."""
|
||||||
|
|
||||||
|
# Reward weights
|
||||||
|
correctness_weight: float = Field(
|
||||||
|
default=0.6,
|
||||||
|
description="Weight for answer correctness in reward (LLM judge score).",
|
||||||
|
)
|
||||||
|
tool_usage_weight: float = Field(
|
||||||
|
default=0.2,
|
||||||
|
description="Weight for tool usage signal (did the model actually use web tools?).",
|
||||||
|
)
|
||||||
|
efficiency_weight: float = Field(
|
||||||
|
default=0.2,
|
||||||
|
description="Weight for efficiency signal (penalizes excessive tool calls).",
|
||||||
|
)
|
||||||
|
diversity_bonus: float = Field(
|
||||||
|
default=0.1,
|
||||||
|
description="Bonus reward for citing ≥2 distinct domains.",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Efficiency thresholds
|
||||||
|
efficient_max_calls: int = Field(
|
||||||
|
default=5,
|
||||||
|
description="Maximum tool calls before efficiency penalty begins.",
|
||||||
|
)
|
||||||
|
heavy_penalty_calls: int = Field(
|
||||||
|
default=10,
|
||||||
|
description="Tool call count where efficiency penalty steepens.",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Eval
|
||||||
|
eval_size: int = Field(
|
||||||
|
default=20,
|
||||||
|
description="Number of held-out items for evaluation.",
|
||||||
|
)
|
||||||
|
eval_split_ratio: float = Field(
|
||||||
|
default=0.1,
|
||||||
|
description="Fraction of dataset to hold out for evaluation (0.0–1.0).",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Dataset
|
||||||
|
dataset_name: str = Field(
|
||||||
|
default="google/frames-benchmark",
|
||||||
|
description="HuggingFace dataset name for research questions.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Environment
|
# Environment
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
@ -143,23 +208,60 @@ class WebResearchEnv(HermesAgentBaseEnv):
|
||||||
Reward is multi-signal:
|
Reward is multi-signal:
|
||||||
60% — answer correctness (LLM judge)
|
60% — answer correctness (LLM judge)
|
||||||
20% — tool usage (did the model actually search the web?)
|
20% — tool usage (did the model actually search the web?)
|
||||||
20% — efficiency (penalizes >6 tool calls)
|
20% — efficiency (penalizes >5 tool calls)
|
||||||
|
|
||||||
Bonus +0.1 for source diversity (≥2 distinct domains cited).
|
Bonus +0.1 for source diversity (≥2 distinct domains cited).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
name = "web-research"
|
name = "web-research"
|
||||||
|
env_config_cls = WebResearchEnvConfig
|
||||||
|
|
||||||
# Default toolsets for this environment — web + file for saving notes
|
# Default toolsets for this environment — web + file for saving notes
|
||||||
default_toolsets = ["web", "file"]
|
default_toolsets = ["web", "file"]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def config_init(cls) -> Tuple[WebResearchEnvConfig, List[APIServerConfig]]:
|
||||||
|
"""Default configuration for the web research environment."""
|
||||||
|
env_config = WebResearchEnvConfig(
|
||||||
|
enabled_toolsets=["web", "file"],
|
||||||
|
max_agent_turns=15,
|
||||||
|
agent_temperature=1.0,
|
||||||
|
system_prompt=(
|
||||||
|
"You are a highly capable research agent. When asked a factual question, "
|
||||||
|
"always use web_search to find current, accurate information before answering. "
|
||||||
|
"Cite at least 2 sources. Be concise and accurate."
|
||||||
|
),
|
||||||
|
group_size=4,
|
||||||
|
total_steps=1000,
|
||||||
|
steps_per_eval=100,
|
||||||
|
use_wandb=True,
|
||||||
|
wandb_name="web-research",
|
||||||
|
)
|
||||||
|
|
||||||
|
server_configs = [
|
||||||
|
APIServerConfig(
|
||||||
|
base_url="https://openrouter.ai/api/v1",
|
||||||
|
model_name="anthropic/claude-sonnet-4.5",
|
||||||
|
server_type="openai",
|
||||||
|
api_key=os.getenv("OPENROUTER_API_KEY", ""),
|
||||||
|
health_check=False,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
return env_config, server_configs
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self._items: list[dict] = []
|
self._items: list[dict] = []
|
||||||
self._eval_items: list[dict] = []
|
self._eval_items: list[dict] = []
|
||||||
self._index: int = 0
|
self._index: int = 0
|
||||||
self._total_scored: int = 0
|
|
||||||
self._total_reward: float = 0.0
|
# Metrics tracking for wandb
|
||||||
|
self._reward_buffer: list[float] = []
|
||||||
|
self._correctness_buffer: list[float] = []
|
||||||
|
self._tool_usage_buffer: list[float] = []
|
||||||
|
self._efficiency_buffer: list[float] = []
|
||||||
|
self._diversity_buffer: list[float] = []
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# 1. Setup — load dataset
|
# 1. Setup — load dataset
|
||||||
|
|
@ -170,7 +272,7 @@ class WebResearchEnv(HermesAgentBaseEnv):
|
||||||
if HF_AVAILABLE:
|
if HF_AVAILABLE:
|
||||||
try:
|
try:
|
||||||
logger.info("Loading FRAMES benchmark from HuggingFace...")
|
logger.info("Loading FRAMES benchmark from HuggingFace...")
|
||||||
ds = load_dataset("google/frames-benchmark", split="test")
|
ds = load_dataset(self.config.dataset_name, split="test")
|
||||||
self._items = [
|
self._items = [
|
||||||
{
|
{
|
||||||
"question": row["Prompt"],
|
"question": row["Prompt"],
|
||||||
|
|
@ -180,8 +282,11 @@ class WebResearchEnv(HermesAgentBaseEnv):
|
||||||
}
|
}
|
||||||
for row in ds
|
for row in ds
|
||||||
]
|
]
|
||||||
# Hold out 10% for eval
|
# Hold out for eval
|
||||||
eval_size = max(20, len(self._items) // 10)
|
eval_size = max(
|
||||||
|
self.config.eval_size,
|
||||||
|
int(len(self._items) * self.config.eval_split_ratio),
|
||||||
|
)
|
||||||
random.shuffle(self._items)
|
random.shuffle(self._items)
|
||||||
self._eval_items = self._items[:eval_size]
|
self._eval_items = self._items[:eval_size]
|
||||||
self._items = self._items[eval_size:]
|
self._items = self._items[eval_size:]
|
||||||
|
|
@ -220,10 +325,7 @@ class WebResearchEnv(HermesAgentBaseEnv):
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
def format_prompt(self, item: dict) -> str:
|
def format_prompt(self, item: dict) -> str:
|
||||||
"""
|
"""Format the research question as a task prompt."""
|
||||||
Format the research question as a task prompt.
|
|
||||||
Instructs the model to use web search and cite sources.
|
|
||||||
"""
|
|
||||||
return (
|
return (
|
||||||
f"Research the following question thoroughly using web search. "
|
f"Research the following question thoroughly using web search. "
|
||||||
f"You MUST search the web to find current, accurate information — "
|
f"You MUST search the web to find current, accurate information — "
|
||||||
|
|
@ -243,27 +345,30 @@ class WebResearchEnv(HermesAgentBaseEnv):
|
||||||
async def compute_reward(
|
async def compute_reward(
|
||||||
self,
|
self,
|
||||||
item: dict,
|
item: dict,
|
||||||
result: dict,
|
result: AgentResult,
|
||||||
ctx: Any, # ToolContext
|
ctx: ToolContext,
|
||||||
) -> float:
|
) -> float:
|
||||||
"""
|
"""
|
||||||
Multi-signal reward function:
|
Multi-signal reward function:
|
||||||
|
|
||||||
0.6 * correctness — LLM judge comparing answer to ground truth
|
correctness_weight * correctness — LLM judge comparing answer to ground truth
|
||||||
0.2 * tool_used — binary: did the model use web tools?
|
tool_usage_weight * tool_used — binary: did the model use web tools?
|
||||||
0.2 * efficiency — penalizes wasteful tool usage
|
efficiency_weight * efficiency — penalizes wasteful tool usage
|
||||||
+0.1 bonus — source diversity (≥2 distinct domains)
|
+ diversity_bonus — source diversity (≥2 distinct domains)
|
||||||
"""
|
"""
|
||||||
final_response: str = result.get("final_response", "")
|
final_response: str = result.final_response or ""
|
||||||
tools_used: list[str] = result.get("tools_used", [])
|
tools_used: list[str] = [
|
||||||
tool_call_count: int = result.get("tool_call_count", len(tools_used))
|
tc.tool_name for tc in (result.tool_calls or [])
|
||||||
|
] if hasattr(result, "tool_calls") and result.tool_calls else []
|
||||||
|
tool_call_count: int = result.turns_used or len(tools_used)
|
||||||
|
|
||||||
|
cfg = self.config
|
||||||
|
|
||||||
# ---- Signal 1: Answer correctness (LLM judge) ----------------
|
# ---- Signal 1: Answer correctness (LLM judge) ----------------
|
||||||
correctness = await self._llm_judge(
|
correctness = await self._llm_judge(
|
||||||
question=item["question"],
|
question=item["question"],
|
||||||
expected=item["answer"],
|
expected=item["answer"],
|
||||||
model_answer=final_response,
|
model_answer=final_response,
|
||||||
ctx=ctx,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# ---- Signal 2: Web tool usage --------------------------------
|
# ---- Signal 2: Web tool usage --------------------------------
|
||||||
|
|
@ -271,35 +376,37 @@ class WebResearchEnv(HermesAgentBaseEnv):
|
||||||
tool_used = 1.0 if any(t in web_tools for t in tools_used) else 0.0
|
tool_used = 1.0 if any(t in web_tools for t in tools_used) else 0.0
|
||||||
|
|
||||||
# ---- Signal 3: Efficiency ------------------------------------
|
# ---- Signal 3: Efficiency ------------------------------------
|
||||||
# Ideal: 2-5 tool calls. Penalise beyond 6, hard cap at 15.
|
if tool_call_count <= cfg.efficient_max_calls:
|
||||||
if tool_call_count <= 5:
|
|
||||||
efficiency = 1.0
|
efficiency = 1.0
|
||||||
elif tool_call_count <= 10:
|
elif tool_call_count <= cfg.heavy_penalty_calls:
|
||||||
efficiency = 1.0 - (tool_call_count - 5) * 0.08
|
efficiency = 1.0 - (tool_call_count - cfg.efficient_max_calls) * 0.08
|
||||||
else:
|
else:
|
||||||
efficiency = max(0.0, 1.0 - (tool_call_count - 5) * 0.12)
|
efficiency = max(0.0, 1.0 - (tool_call_count - cfg.efficient_max_calls) * 0.12)
|
||||||
|
|
||||||
# ---- Bonus: Source diversity ---------------------------------
|
# ---- Bonus: Source diversity ---------------------------------
|
||||||
domains = self._extract_domains(final_response)
|
domains = self._extract_domains(final_response)
|
||||||
diversity_bonus = 0.1 if len(domains) >= 2 else 0.0
|
diversity = cfg.diversity_bonus if len(domains) >= 2 else 0.0
|
||||||
|
|
||||||
# ---- Combine ------------------------------------------------
|
# ---- Combine ------------------------------------------------
|
||||||
reward = (
|
reward = (
|
||||||
0.6 * correctness
|
cfg.correctness_weight * correctness
|
||||||
+ 0.2 * tool_used
|
+ cfg.tool_usage_weight * tool_used
|
||||||
+ 0.2 * efficiency
|
+ cfg.efficiency_weight * efficiency
|
||||||
+ diversity_bonus
|
+ diversity
|
||||||
)
|
)
|
||||||
reward = min(1.0, max(0.0, reward)) # clamp to [0, 1]
|
reward = min(1.0, max(0.0, reward)) # clamp to [0, 1]
|
||||||
|
|
||||||
# Track running stats
|
# Track for wandb
|
||||||
self._total_scored += 1
|
self._reward_buffer.append(reward)
|
||||||
self._total_reward += reward
|
self._correctness_buffer.append(correctness)
|
||||||
|
self._tool_usage_buffer.append(tool_used)
|
||||||
|
self._efficiency_buffer.append(efficiency)
|
||||||
|
self._diversity_buffer.append(diversity)
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"Reward breakdown — correctness={correctness:.2f}, "
|
f"Reward breakdown — correctness={correctness:.2f}, "
|
||||||
f"tool_used={tool_used:.1f}, efficiency={efficiency:.2f}, "
|
f"tool_used={tool_used:.1f}, efficiency={efficiency:.2f}, "
|
||||||
f"diversity_bonus={diversity_bonus:.1f} → total={reward:.3f}"
|
f"diversity={diversity:.1f} → total={reward:.3f}"
|
||||||
)
|
)
|
||||||
|
|
||||||
return reward
|
return reward
|
||||||
|
|
@ -308,68 +415,117 @@ class WebResearchEnv(HermesAgentBaseEnv):
|
||||||
# 5. evaluate — run on held-out eval split
|
# 5. evaluate — run on held-out eval split
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
async def evaluate(
|
async def evaluate(self, *args, **kwargs) -> None:
|
||||||
self,
|
"""Run evaluation on the held-out split using the agent loop."""
|
||||||
*args: Any,
|
import time
|
||||||
eval_size: Optional[int] = None,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> dict:
|
|
||||||
"""
|
|
||||||
Run evaluation on the held-out split.
|
|
||||||
Returns a dict of metrics for logging.
|
|
||||||
"""
|
|
||||||
items = self._eval_items
|
|
||||||
if eval_size:
|
|
||||||
items = items[:eval_size]
|
|
||||||
|
|
||||||
|
items = self._eval_items
|
||||||
if not items:
|
if not items:
|
||||||
logger.warning("No eval items available.")
|
logger.warning("No eval items available.")
|
||||||
return {}
|
return
|
||||||
|
|
||||||
logger.info(f"Running eval on {len(items)} questions...")
|
eval_size = min(self.config.eval_size, len(items))
|
||||||
|
eval_items = items[:eval_size]
|
||||||
|
|
||||||
rewards = []
|
logger.info(f"Running eval on {len(eval_items)} questions...")
|
||||||
correctness_scores = []
|
start_time = time.time()
|
||||||
|
samples = []
|
||||||
|
|
||||||
for item in items:
|
for item in eval_items:
|
||||||
try:
|
try:
|
||||||
# Run the agent on each eval question
|
# Use the base env's agent loop for eval (same as training)
|
||||||
result = await self._run_agent_on_item(item)
|
prompt = self.format_prompt(item)
|
||||||
reward = await self.compute_reward(item, result, ctx=None)
|
completion = await self.server.chat_completion(
|
||||||
rewards.append(reward)
|
messages=[
|
||||||
|
{"role": "system", "content": self.config.system_prompt or ""},
|
||||||
|
{"role": "user", "content": prompt},
|
||||||
|
],
|
||||||
|
n=1,
|
||||||
|
max_tokens=self.config.max_token_length,
|
||||||
|
temperature=0.0,
|
||||||
|
split="eval",
|
||||||
|
)
|
||||||
|
|
||||||
|
response_content = (
|
||||||
|
completion.choices[0].message.content if completion.choices else ""
|
||||||
|
)
|
||||||
|
|
||||||
|
# Score the response
|
||||||
|
correctness = await self._llm_judge(
|
||||||
|
question=item["question"],
|
||||||
|
expected=item["answer"],
|
||||||
|
model_answer=response_content,
|
||||||
|
)
|
||||||
|
|
||||||
|
samples.append({
|
||||||
|
"prompt": item["question"],
|
||||||
|
"response": response_content,
|
||||||
|
"expected": item["answer"],
|
||||||
|
"correctness": correctness,
|
||||||
|
})
|
||||||
|
|
||||||
# Also track raw correctness separately
|
|
||||||
if result.get("final_response"):
|
|
||||||
correctness_scores.append(
|
|
||||||
await self._llm_judge(
|
|
||||||
question=item["question"],
|
|
||||||
expected=item["answer"],
|
|
||||||
model_answer=result["final_response"],
|
|
||||||
ctx=None,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Eval error on item: {e}")
|
logger.error(f"Eval error on item: {e}")
|
||||||
rewards.append(0.0)
|
samples.append({
|
||||||
|
"prompt": item["question"],
|
||||||
|
"response": f"ERROR: {e}",
|
||||||
|
"expected": item["answer"],
|
||||||
|
"correctness": 0.0,
|
||||||
|
})
|
||||||
|
|
||||||
metrics = {
|
end_time = time.time()
|
||||||
"eval/mean_reward": sum(rewards) / len(rewards) if rewards else 0.0,
|
|
||||||
|
# Compute metrics
|
||||||
|
correctness_scores = [s["correctness"] for s in samples]
|
||||||
|
eval_metrics = {
|
||||||
"eval/mean_correctness": (
|
"eval/mean_correctness": (
|
||||||
sum(correctness_scores) / len(correctness_scores)
|
sum(correctness_scores) / len(correctness_scores)
|
||||||
if correctness_scores else 0.0
|
if correctness_scores else 0.0
|
||||||
),
|
),
|
||||||
"eval/n_items": len(rewards),
|
"eval/n_items": len(samples),
|
||||||
"train/mean_reward_so_far": (
|
|
||||||
self._total_reward / self._total_scored
|
|
||||||
if self._total_scored > 0 else 0.0
|
|
||||||
),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.info(
|
await self.evaluate_log(
|
||||||
f"Eval complete — mean_reward={metrics['eval/mean_reward']:.3f}, "
|
metrics=eval_metrics,
|
||||||
f"mean_correctness={metrics['eval/mean_correctness']:.3f}"
|
samples=samples,
|
||||||
|
start_time=start_time,
|
||||||
|
end_time=end_time,
|
||||||
)
|
)
|
||||||
return metrics
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# 6. wandb_log — custom metrics
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def wandb_log(self, wandb_metrics: Optional[Dict] = None) -> None:
|
||||||
|
"""Log reward breakdown metrics to wandb."""
|
||||||
|
if wandb_metrics is None:
|
||||||
|
wandb_metrics = {}
|
||||||
|
|
||||||
|
if self._reward_buffer:
|
||||||
|
n = len(self._reward_buffer)
|
||||||
|
wandb_metrics["train/mean_reward"] = sum(self._reward_buffer) / n
|
||||||
|
wandb_metrics["train/mean_correctness"] = sum(self._correctness_buffer) / n
|
||||||
|
wandb_metrics["train/mean_tool_usage"] = sum(self._tool_usage_buffer) / n
|
||||||
|
wandb_metrics["train/mean_efficiency"] = sum(self._efficiency_buffer) / n
|
||||||
|
wandb_metrics["train/mean_diversity"] = sum(self._diversity_buffer) / n
|
||||||
|
wandb_metrics["train/total_rollouts"] = n
|
||||||
|
|
||||||
|
# Accuracy buckets
|
||||||
|
wandb_metrics["train/correct_rate"] = (
|
||||||
|
sum(1 for c in self._correctness_buffer if c >= 0.7) / n
|
||||||
|
)
|
||||||
|
wandb_metrics["train/tool_usage_rate"] = (
|
||||||
|
sum(1 for t in self._tool_usage_buffer if t > 0) / n
|
||||||
|
)
|
||||||
|
|
||||||
|
# Clear buffers
|
||||||
|
self._reward_buffer.clear()
|
||||||
|
self._correctness_buffer.clear()
|
||||||
|
self._tool_usage_buffer.clear()
|
||||||
|
self._efficiency_buffer.clear()
|
||||||
|
self._diversity_buffer.clear()
|
||||||
|
|
||||||
|
await super().wandb_log(wandb_metrics)
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# Private helpers
|
# Private helpers
|
||||||
|
|
@ -380,19 +536,14 @@ class WebResearchEnv(HermesAgentBaseEnv):
|
||||||
question: str,
|
question: str,
|
||||||
expected: str,
|
expected: str,
|
||||||
model_answer: str,
|
model_answer: str,
|
||||||
ctx: Any,
|
|
||||||
) -> float:
|
) -> float:
|
||||||
"""
|
"""
|
||||||
Use an LLM to judge whether `model_answer` correctly addresses
|
Use the server's LLM to judge answer correctness.
|
||||||
`question` compared to `expected`. Returns a float in [0, 1].
|
Falls back to keyword heuristic if LLM call fails.
|
||||||
|
|
||||||
Uses the agent's own inference client if ctx is available,
|
|
||||||
otherwise falls back to a lightweight heuristic.
|
|
||||||
"""
|
"""
|
||||||
if not model_answer or not model_answer.strip():
|
if not model_answer or not model_answer.strip():
|
||||||
return 0.0
|
return 0.0
|
||||||
|
|
||||||
# Build judge prompt
|
|
||||||
judge_prompt = (
|
judge_prompt = (
|
||||||
"You are an impartial judge evaluating the quality of an AI research answer.\n\n"
|
"You are an impartial judge evaluating the quality of an AI research answer.\n\n"
|
||||||
f"Question: {question}\n\n"
|
f"Question: {question}\n\n"
|
||||||
|
|
@ -405,39 +556,36 @@ class WebResearchEnv(HermesAgentBaseEnv):
|
||||||
" 0.1 = mentions relevant topic but wrong or very incomplete\n"
|
" 0.1 = mentions relevant topic but wrong or very incomplete\n"
|
||||||
" 0.0 = completely wrong or no answer\n\n"
|
" 0.0 = completely wrong or no answer\n\n"
|
||||||
"Consider: factual accuracy, completeness, and relevance.\n"
|
"Consider: factual accuracy, completeness, and relevance.\n"
|
||||||
"Respond with ONLY a JSON object: {\"score\": <float>, \"reason\": \"<one sentence>\"}"
|
'Respond with ONLY a JSON object: {"score": <float>, "reason": "<one sentence>"}'
|
||||||
)
|
)
|
||||||
|
|
||||||
# Try using ctx for inference (Phase 2 / live training)
|
try:
|
||||||
if ctx is not None and hasattr(ctx, "chat_completion"):
|
response = await self.server.chat_completion(
|
||||||
try:
|
messages=[{"role": "user", "content": judge_prompt}],
|
||||||
response = await ctx.chat_completion(
|
n=1,
|
||||||
messages=[{"role": "user", "content": judge_prompt}],
|
max_tokens=150,
|
||||||
max_tokens=100,
|
temperature=0.0,
|
||||||
temperature=0.0,
|
split="eval",
|
||||||
)
|
)
|
||||||
text = response.get("content", "")
|
text = response.choices[0].message.content if response.choices else ""
|
||||||
parsed = self._parse_judge_json(text)
|
parsed = self._parse_judge_json(text)
|
||||||
if parsed is not None:
|
if parsed is not None:
|
||||||
return float(parsed)
|
return float(parsed)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug(f"LLM judge via ctx failed: {e}. Using heuristic.")
|
logger.debug(f"LLM judge failed: {e}. Using heuristic.")
|
||||||
|
|
||||||
# Fallback: keyword overlap heuristic
|
|
||||||
return self._heuristic_score(expected, model_answer)
|
return self._heuristic_score(expected, model_answer)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _parse_judge_json(text: str) -> Optional[float]:
|
def _parse_judge_json(text: str) -> Optional[float]:
|
||||||
"""Extract the score float from LLM judge JSON response."""
|
"""Extract the score float from LLM judge JSON response."""
|
||||||
try:
|
try:
|
||||||
# Strip markdown code fences if present
|
|
||||||
clean = re.sub(r"```(?:json)?|```", "", text).strip()
|
clean = re.sub(r"```(?:json)?|```", "", text).strip()
|
||||||
data = json.loads(clean)
|
data = json.loads(clean)
|
||||||
score = float(data.get("score", -1))
|
score = float(data.get("score", -1))
|
||||||
if 0.0 <= score <= 1.0:
|
if 0.0 <= score <= 1.0:
|
||||||
return score
|
return score
|
||||||
except Exception:
|
except Exception:
|
||||||
# Try regex fallback
|
|
||||||
match = re.search(r'"score"\s*:\s*([0-9.]+)', text)
|
match = re.search(r'"score"\s*:\s*([0-9.]+)', text)
|
||||||
if match:
|
if match:
|
||||||
score = float(match.group(1))
|
score = float(match.group(1))
|
||||||
|
|
@ -447,10 +595,7 @@ class WebResearchEnv(HermesAgentBaseEnv):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _heuristic_score(expected: str, model_answer: str) -> float:
|
def _heuristic_score(expected: str, model_answer: str) -> float:
|
||||||
"""
|
"""Lightweight keyword overlap score as fallback."""
|
||||||
Lightweight keyword overlap score as fallback when no LLM is available.
|
|
||||||
Extracts meaningful tokens and computes Jaccard similarity.
|
|
||||||
"""
|
|
||||||
stopwords = {
|
stopwords = {
|
||||||
"the", "a", "an", "is", "are", "was", "were", "of", "in", "on",
|
"the", "a", "an", "is", "are", "was", "were", "of", "in", "on",
|
||||||
"at", "to", "for", "with", "and", "or", "but", "it", "its",
|
"at", "to", "for", "with", "and", "or", "but", "it", "its",
|
||||||
|
|
@ -458,35 +603,30 @@ class WebResearchEnv(HermesAgentBaseEnv):
|
||||||
}
|
}
|
||||||
|
|
||||||
def tokenize(text: str) -> set:
|
def tokenize(text: str) -> set:
|
||||||
tokens = re.findall(r'\b[a-zA-Z0-9]+\b', text.lower())
|
tokens = re.findall(r'\b\w+\b', text.lower())
|
||||||
return {t for t in tokens if t not in stopwords and len(t) > 2}
|
return {t for t in tokens if t not in stopwords and len(t) > 2}
|
||||||
|
|
||||||
expected_tokens = tokenize(expected)
|
expected_tokens = tokenize(expected)
|
||||||
answer_tokens = tokenize(model_answer)
|
answer_tokens = tokenize(model_answer)
|
||||||
|
|
||||||
if not expected_tokens:
|
if not expected_tokens:
|
||||||
return 0.5 # Can't judge
|
return 0.5
|
||||||
|
|
||||||
overlap = len(expected_tokens & answer_tokens)
|
overlap = len(expected_tokens & answer_tokens)
|
||||||
union = len(expected_tokens | answer_tokens)
|
union = len(expected_tokens | answer_tokens)
|
||||||
|
|
||||||
jaccard = overlap / union if union > 0 else 0.0
|
jaccard = overlap / union if union > 0 else 0.0
|
||||||
# Recall-weighted: reward covering expected content
|
|
||||||
recall = overlap / len(expected_tokens)
|
recall = overlap / len(expected_tokens)
|
||||||
return min(1.0, 0.4 * jaccard + 0.6 * recall)
|
return min(1.0, 0.4 * jaccard + 0.6 * recall)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _extract_domains(text: str) -> set:
|
def _extract_domains(text: str) -> set:
|
||||||
"""
|
"""Extract unique domains from URLs cited in the response."""
|
||||||
Extract unique domains from URLs cited in the response.
|
|
||||||
Used to measure source diversity.
|
|
||||||
"""
|
|
||||||
urls = re.findall(r'https?://[^\s\)>\]"\']+', text)
|
urls = re.findall(r'https?://[^\s\)>\]"\']+', text)
|
||||||
domains = set()
|
domains = set()
|
||||||
for url in urls:
|
for url in urls:
|
||||||
try:
|
try:
|
||||||
parsed = urlparse(url)
|
parsed = urlparse(url)
|
||||||
# Normalize: strip www.
|
|
||||||
domain = parsed.netloc.lower().lstrip("www.")
|
domain = parsed.netloc.lower().lstrip("www.")
|
||||||
if domain:
|
if domain:
|
||||||
domains.add(domain)
|
domains.add(domain)
|
||||||
|
|
@ -494,20 +634,6 @@ class WebResearchEnv(HermesAgentBaseEnv):
|
||||||
pass
|
pass
|
||||||
return domains
|
return domains
|
||||||
|
|
||||||
async def _run_agent_on_item(self, item: dict) -> dict:
|
|
||||||
"""
|
|
||||||
Stub for running agent during eval. In Phase 1/2, this is handled
|
|
||||||
by the Atropos framework's rollout mechanism. Provided here for
|
|
||||||
standalone eval compatibility.
|
|
||||||
"""
|
|
||||||
# In real usage, the framework calls get_next_item + format_prompt
|
|
||||||
# and runs the agent. This stub returns an empty result for safety.
|
|
||||||
return {
|
|
||||||
"final_response": "",
|
|
||||||
"tools_used": [],
|
|
||||||
"tool_call_count": 0,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Entry point
|
# Entry point
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue