Merge: WebResearchEnv Atropos standards compliance
This commit is contained in:
commit
8bc0d4f77d
1 changed file with 270 additions and 144 deletions
|
|
@ -16,21 +16,18 @@ Dataset: FRAMES benchmark (Google, 2024) — multi-hop factual questions
|
||||||
|
|
||||||
Usage:
|
Usage:
|
||||||
# Phase 1 (OpenAI-compatible server)
|
# Phase 1 (OpenAI-compatible server)
|
||||||
python environments/web_research_env.py serve \
|
python environments/web_research_env.py serve \\
|
||||||
--openai.base_url http://localhost:8000/v1 \
|
--openai.base_url http://localhost:8000/v1 \\
|
||||||
--openai.model_name YourModel \
|
--openai.model_name YourModel \\
|
||||||
--openai.server_type openai
|
--openai.server_type openai
|
||||||
|
|
||||||
# With eval split
|
# Process mode (offline data generation)
|
||||||
python environments/web_research_env.py serve \
|
python environments/web_research_env.py process \\
|
||||||
--openai.base_url http://localhost:8000/v1 \
|
--env.data_path_to_save_groups data/web_research.jsonl
|
||||||
--openai.model_name YourModel \
|
|
||||||
--env.eval_every 50 \
|
|
||||||
--env.eval_size 20
|
|
||||||
|
|
||||||
# Standalone eval (no training server needed)
|
# Standalone eval
|
||||||
python environments/web_research_env.py eval \
|
python environments/web_research_env.py evaluate \\
|
||||||
--openai.base_url http://localhost:8000/v1 \
|
--openai.base_url http://localhost:8000/v1 \\
|
||||||
--openai.model_name YourModel
|
--openai.model_name YourModel
|
||||||
|
|
||||||
Built by: github.com/jackx707
|
Built by: github.com/jackx707
|
||||||
|
|
@ -43,11 +40,21 @@ from __future__ import annotations
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
import random
|
import random
|
||||||
import re
|
import re
|
||||||
from typing import Any, Optional
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
from pydantic import Field
|
||||||
|
|
||||||
|
# Ensure hermes-agent root is on path
|
||||||
|
_repo_root = Path(__file__).resolve().parent.parent
|
||||||
|
if str(_repo_root) not in sys.path:
|
||||||
|
sys.path.insert(0, str(_repo_root))
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Optional HuggingFace datasets import
|
# Optional HuggingFace datasets import
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
@ -57,13 +64,19 @@ try:
|
||||||
except ImportError:
|
except ImportError:
|
||||||
HF_AVAILABLE = False
|
HF_AVAILABLE = False
|
||||||
|
|
||||||
from environments.hermes_base_env import HermesAgentBaseEnv
|
from atroposlib.envs.base import ScoredDataGroup
|
||||||
|
from atroposlib.envs.server_handling.server_manager import APIServerConfig
|
||||||
|
from atroposlib.type_definitions import Item
|
||||||
|
|
||||||
|
from environments.hermes_base_env import HermesAgentBaseEnv, HermesAgentEnvConfig
|
||||||
|
from environments.agent_loop import AgentResult
|
||||||
|
from environments.tool_context import ToolContext
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Fallback sample dataset (used when HuggingFace is unavailable)
|
# Fallback sample dataset (used when HuggingFace is unavailable)
|
||||||
# These are multi-hop questions that require real web search to answer.
|
# Multi-hop questions requiring real web search to answer.
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
SAMPLE_QUESTIONS = [
|
SAMPLE_QUESTIONS = [
|
||||||
{
|
{
|
||||||
|
|
@ -129,6 +142,58 @@ SAMPLE_QUESTIONS = [
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Configuration
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class WebResearchEnvConfig(HermesAgentEnvConfig):
    """Configuration for the web research RL environment.

    Extends ``HermesAgentEnvConfig`` with reward weights, efficiency
    thresholds, eval-split sizing and the dataset name.

    Numeric fields carry range constraints so that an out-of-range value
    (a negative weight, a split ratio above 1.0, ...) is rejected when the
    config is built rather than silently distorting rewards or the
    train/eval split.
    """

    # Reward weights
    correctness_weight: float = Field(
        default=0.6,
        ge=0.0,
        description="Weight for answer correctness in reward (LLM judge score).",
    )
    tool_usage_weight: float = Field(
        default=0.2,
        ge=0.0,
        description="Weight for tool usage signal (did the model actually use web tools?).",
    )
    efficiency_weight: float = Field(
        default=0.2,
        ge=0.0,
        description="Weight for efficiency signal (penalizes excessive tool calls).",
    )
    diversity_bonus: float = Field(
        default=0.1,
        ge=0.0,
        description="Bonus reward for citing ≥2 distinct domains.",
    )

    # Efficiency thresholds
    efficient_max_calls: int = Field(
        default=5,
        ge=0,
        description="Maximum tool calls before efficiency penalty begins.",
    )
    heavy_penalty_calls: int = Field(
        default=10,
        ge=0,
        description="Tool call count where efficiency penalty steepens.",
    )

    # Eval
    eval_size: int = Field(
        default=20,
        ge=0,
        description="Number of held-out items for evaluation.",
    )
    # The description promises a fraction; enforce it so a typo such as 10
    # (meant as 10%) cannot hold out more items than the dataset contains.
    eval_split_ratio: float = Field(
        default=0.1,
        ge=0.0,
        le=1.0,
        description="Fraction of dataset to hold out for evaluation (0.0–1.0).",
    )

    # Dataset
    dataset_name: str = Field(
        default="google/frames-benchmark",
        description="HuggingFace dataset name for research questions.",
    )
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Environment
|
# Environment
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
@ -143,23 +208,60 @@ class WebResearchEnv(HermesAgentBaseEnv):
|
||||||
Reward is multi-signal:
|
Reward is multi-signal:
|
||||||
60% — answer correctness (LLM judge)
|
60% — answer correctness (LLM judge)
|
||||||
20% — tool usage (did the model actually search the web?)
|
20% — tool usage (did the model actually search the web?)
|
||||||
20% — efficiency (penalizes >6 tool calls)
|
20% — efficiency (penalizes >5 tool calls)
|
||||||
|
|
||||||
Bonus +0.1 for source diversity (≥2 distinct domains cited).
|
Bonus +0.1 for source diversity (≥2 distinct domains cited).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
name = "web-research"
|
name = "web-research"
|
||||||
|
env_config_cls = WebResearchEnvConfig
|
||||||
|
|
||||||
# Default toolsets for this environment — web + file for saving notes
|
# Default toolsets for this environment — web + file for saving notes
|
||||||
default_toolsets = ["web", "file"]
|
default_toolsets = ["web", "file"]
|
||||||
|
|
||||||
|
@classmethod
def config_init(cls) -> Tuple[WebResearchEnvConfig, List[APIServerConfig]]:
    """Build the default environment config and API server list.

    Returns:
        A ``(env_config, server_configs)`` pair: a ``WebResearchEnvConfig``
        with the web and file toolsets enabled, and a one-element list
        holding an OpenRouter ``APIServerConfig`` whose key is read from
        the ``OPENROUTER_API_KEY`` environment variable (empty string when
        unset).
    """
    research_prompt = (
        "You are a highly capable research agent. When asked a factual question, "
        "always use web_search to find current, accurate information before answering. "
        "Cite at least 2 sources. Be concise and accurate."
    )

    env_defaults = {
        "enabled_toolsets": ["web", "file"],
        "max_agent_turns": 15,
        "agent_temperature": 1.0,
        "system_prompt": research_prompt,
        "group_size": 4,
        "total_steps": 1000,
        "steps_per_eval": 100,
        "use_wandb": True,
        "wandb_name": "web-research",
    }
    env_config = WebResearchEnvConfig(**env_defaults)

    openrouter_server = APIServerConfig(
        base_url="https://openrouter.ai/api/v1",
        model_name="anthropic/claude-sonnet-4.5",
        server_type="openai",
        api_key=os.getenv("OPENROUTER_API_KEY", ""),
        health_check=False,
    )

    return env_config, [openrouter_server]
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
    """Initialise the base environment, then the dataset and metric state.

    All positional and keyword arguments are forwarded unchanged to the
    parent constructor.
    """
    super().__init__(*args, **kwargs)

    # Dataset state: both item lists start empty and the position counter
    # starts at zero.
    self._items: list[dict] = []
    self._eval_items: list[dict] = []
    self._index: int = 0

    # Metrics tracking for wandb: one independent float buffer per reward
    # signal, each starting empty.
    for buffer_attr in (
        "_reward_buffer",
        "_correctness_buffer",
        "_tool_usage_buffer",
        "_efficiency_buffer",
        "_diversity_buffer",
    ):
        setattr(self, buffer_attr, [])
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# 1. Setup — load dataset
|
# 1. Setup — load dataset
|
||||||
|
|
@ -170,7 +272,7 @@ class WebResearchEnv(HermesAgentBaseEnv):
|
||||||
if HF_AVAILABLE:
|
if HF_AVAILABLE:
|
||||||
try:
|
try:
|
||||||
logger.info("Loading FRAMES benchmark from HuggingFace...")
|
logger.info("Loading FRAMES benchmark from HuggingFace...")
|
||||||
ds = load_dataset("google/frames-benchmark", split="test")
|
ds = load_dataset(self.config.dataset_name, split="test")
|
||||||
self._items = [
|
self._items = [
|
||||||
{
|
{
|
||||||
"question": row["Prompt"],
|
"question": row["Prompt"],
|
||||||
|
|
@ -180,8 +282,11 @@ class WebResearchEnv(HermesAgentBaseEnv):
|
||||||
}
|
}
|
||||||
for row in ds
|
for row in ds
|
||||||
]
|
]
|
||||||
# Hold out 10% for eval
|
# Hold out for eval
|
||||||
eval_size = max(20, len(self._items) // 10)
|
eval_size = max(
|
||||||
|
self.config.eval_size,
|
||||||
|
int(len(self._items) * self.config.eval_split_ratio),
|
||||||
|
)
|
||||||
random.shuffle(self._items)
|
random.shuffle(self._items)
|
||||||
self._eval_items = self._items[:eval_size]
|
self._eval_items = self._items[:eval_size]
|
||||||
self._items = self._items[eval_size:]
|
self._items = self._items[eval_size:]
|
||||||
|
|
@ -220,10 +325,7 @@ class WebResearchEnv(HermesAgentBaseEnv):
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
def format_prompt(self, item: dict) -> str:
|
def format_prompt(self, item: dict) -> str:
|
||||||
"""
|
"""Format the research question as a task prompt."""
|
||||||
Format the research question as a task prompt.
|
|
||||||
Instructs the model to use web search and cite sources.
|
|
||||||
"""
|
|
||||||
return (
|
return (
|
||||||
f"Research the following question thoroughly using web search. "
|
f"Research the following question thoroughly using web search. "
|
||||||
f"You MUST search the web to find current, accurate information — "
|
f"You MUST search the web to find current, accurate information — "
|
||||||
|
|
@ -243,27 +345,30 @@ class WebResearchEnv(HermesAgentBaseEnv):
|
||||||
async def compute_reward(
|
async def compute_reward(
|
||||||
self,
|
self,
|
||||||
item: dict,
|
item: dict,
|
||||||
result: dict,
|
result: AgentResult,
|
||||||
ctx: Any, # ToolContext
|
ctx: ToolContext,
|
||||||
) -> float:
|
) -> float:
|
||||||
"""
|
"""
|
||||||
Multi-signal reward function:
|
Multi-signal reward function:
|
||||||
|
|
||||||
0.6 * correctness — LLM judge comparing answer to ground truth
|
correctness_weight * correctness — LLM judge comparing answer to ground truth
|
||||||
0.2 * tool_used — binary: did the model use web tools?
|
tool_usage_weight * tool_used — binary: did the model use web tools?
|
||||||
0.2 * efficiency — penalizes wasteful tool usage
|
efficiency_weight * efficiency — penalizes wasteful tool usage
|
||||||
+0.1 bonus — source diversity (≥2 distinct domains)
|
+ diversity_bonus — source diversity (≥2 distinct domains)
|
||||||
"""
|
"""
|
||||||
final_response: str = result.get("final_response", "")
|
final_response: str = result.final_response or ""
|
||||||
tools_used: list[str] = result.get("tools_used", [])
|
tools_used: list[str] = [
|
||||||
tool_call_count: int = result.get("tool_call_count", len(tools_used))
|
tc.tool_name for tc in (result.tool_calls or [])
|
||||||
|
] if hasattr(result, "tool_calls") and result.tool_calls else []
|
||||||
|
tool_call_count: int = result.turns_used or len(tools_used)
|
||||||
|
|
||||||
|
cfg = self.config
|
||||||
|
|
||||||
# ---- Signal 1: Answer correctness (LLM judge) ----------------
|
# ---- Signal 1: Answer correctness (LLM judge) ----------------
|
||||||
correctness = await self._llm_judge(
|
correctness = await self._llm_judge(
|
||||||
question=item["question"],
|
question=item["question"],
|
||||||
expected=item["answer"],
|
expected=item["answer"],
|
||||||
model_answer=final_response,
|
model_answer=final_response,
|
||||||
ctx=ctx,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# ---- Signal 2: Web tool usage --------------------------------
|
# ---- Signal 2: Web tool usage --------------------------------
|
||||||
|
|
@ -271,35 +376,37 @@ class WebResearchEnv(HermesAgentBaseEnv):
|
||||||
tool_used = 1.0 if any(t in web_tools for t in tools_used) else 0.0
|
tool_used = 1.0 if any(t in web_tools for t in tools_used) else 0.0
|
||||||
|
|
||||||
# ---- Signal 3: Efficiency ------------------------------------
|
# ---- Signal 3: Efficiency ------------------------------------
|
||||||
# Ideal: 2-5 tool calls. Penalise beyond 6, hard cap at 15.
|
if tool_call_count <= cfg.efficient_max_calls:
|
||||||
if tool_call_count <= 5:
|
|
||||||
efficiency = 1.0
|
efficiency = 1.0
|
||||||
elif tool_call_count <= 10:
|
elif tool_call_count <= cfg.heavy_penalty_calls:
|
||||||
efficiency = 1.0 - (tool_call_count - 5) * 0.08
|
efficiency = 1.0 - (tool_call_count - cfg.efficient_max_calls) * 0.08
|
||||||
else:
|
else:
|
||||||
efficiency = max(0.0, 1.0 - (tool_call_count - 5) * 0.12)
|
efficiency = max(0.0, 1.0 - (tool_call_count - cfg.efficient_max_calls) * 0.12)
|
||||||
|
|
||||||
# ---- Bonus: Source diversity ---------------------------------
|
# ---- Bonus: Source diversity ---------------------------------
|
||||||
domains = self._extract_domains(final_response)
|
domains = self._extract_domains(final_response)
|
||||||
diversity_bonus = 0.1 if len(domains) >= 2 else 0.0
|
diversity = cfg.diversity_bonus if len(domains) >= 2 else 0.0
|
||||||
|
|
||||||
# ---- Combine ------------------------------------------------
|
# ---- Combine ------------------------------------------------
|
||||||
reward = (
|
reward = (
|
||||||
0.6 * correctness
|
cfg.correctness_weight * correctness
|
||||||
+ 0.2 * tool_used
|
+ cfg.tool_usage_weight * tool_used
|
||||||
+ 0.2 * efficiency
|
+ cfg.efficiency_weight * efficiency
|
||||||
+ diversity_bonus
|
+ diversity
|
||||||
)
|
)
|
||||||
reward = min(1.0, max(0.0, reward)) # clamp to [0, 1]
|
reward = min(1.0, max(0.0, reward)) # clamp to [0, 1]
|
||||||
|
|
||||||
# Track running stats
|
# Track for wandb
|
||||||
self._total_scored += 1
|
self._reward_buffer.append(reward)
|
||||||
self._total_reward += reward
|
self._correctness_buffer.append(correctness)
|
||||||
|
self._tool_usage_buffer.append(tool_used)
|
||||||
|
self._efficiency_buffer.append(efficiency)
|
||||||
|
self._diversity_buffer.append(diversity)
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"Reward breakdown — correctness={correctness:.2f}, "
|
f"Reward breakdown — correctness={correctness:.2f}, "
|
||||||
f"tool_used={tool_used:.1f}, efficiency={efficiency:.2f}, "
|
f"tool_used={tool_used:.1f}, efficiency={efficiency:.2f}, "
|
||||||
f"diversity_bonus={diversity_bonus:.1f} → total={reward:.3f}"
|
f"diversity={diversity:.1f} → total={reward:.3f}"
|
||||||
)
|
)
|
||||||
|
|
||||||
return reward
|
return reward
|
||||||
|
|
@ -308,68 +415,117 @@ class WebResearchEnv(HermesAgentBaseEnv):
|
||||||
# 5. evaluate — run on held-out eval split
|
# 5. evaluate — run on held-out eval split
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
async def evaluate(self, *args, **kwargs) -> None:
    """Run evaluation on the held-out split and log the results.

    Each eval question is answered with a single direct chat completion
    (system prompt + formatted question, temperature 0) and the reply is
    scored with ``_llm_judge``. No tools are attached to this call, so the
    metric reflects answer correctness only, not the multi-signal training
    reward.

    At most ``config.eval_size`` items are evaluated. A failure on one item
    is logged with its traceback and recorded as a sample with correctness
    0.0, so one bad item does not abort the run. Results are handed to
    ``evaluate_log``; nothing is returned.

    Args:
        *args: Accepted for interface compatibility; unused.
        **kwargs: Accepted for interface compatibility; unused.
    """
    import time

    items = self._eval_items
    if not items:
        logger.warning("No eval items available.")
        return

    eval_size = min(self.config.eval_size, len(items))
    eval_items = items[:eval_size]

    logger.info(f"Running eval on {len(eval_items)} questions...")
    start_time = time.time()
    samples = []

    for item in eval_items:
        try:
            prompt = self.format_prompt(item)
            completion = await self.server.chat_completion(
                messages=[
                    {"role": "system", "content": self.config.system_prompt or ""},
                    {"role": "user", "content": prompt},
                ],
                n=1,
                max_tokens=self.config.max_token_length,
                temperature=0.0,
                split="eval",
            )

            # ``message.content`` may be None (e.g. an empty or refused
            # reply); normalise to "" so samples always hold a string and
            # the judge receives a consistent type.
            response_content = ""
            if completion.choices:
                response_content = completion.choices[0].message.content or ""

            # Score the response
            correctness = await self._llm_judge(
                question=item["question"],
                expected=item["answer"],
                model_answer=response_content,
            )

            samples.append({
                "prompt": item["question"],
                "response": response_content,
                "expected": item["answer"],
                "correctness": correctness,
            })

        except Exception as e:
            # Best-effort: keep evaluating the remaining items, but keep the
            # traceback so the failure can be diagnosed afterwards.
            logger.exception(f"Eval error on item: {e}")
            samples.append({
                "prompt": item["question"],
                "response": f"ERROR: {e}",
                "expected": item["answer"],
                "correctness": 0.0,
            })

    end_time = time.time()

    # Compute metrics
    correctness_scores = [s["correctness"] for s in samples]
    eval_metrics = {
        "eval/mean_correctness": (
            sum(correctness_scores) / len(correctness_scores)
            if correctness_scores else 0.0
        ),
        "eval/n_items": len(samples),
    }

    await self.evaluate_log(
        metrics=eval_metrics,
        samples=samples,
        start_time=start_time,
        end_time=end_time,
    )
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# 6. wandb_log — custom metrics
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def wandb_log(self, wandb_metrics: Optional[Dict] = None) -> None:
    """Add reward-breakdown metrics to ``wandb_metrics`` and forward them.

    When rollouts have been scored since the last call, the per-signal
    means, the rollout count and two rate metrics are written into the
    dict and every buffer is emptied. The dict (a fresh one when ``None``
    was passed) is then handed to the parent ``wandb_log``.

    Args:
        wandb_metrics: Metrics collected so far; mutated in place.
    """
    metrics = {} if wandb_metrics is None else wandb_metrics

    rollouts = len(self._reward_buffer)
    if rollouts:
        # Insertion order mirrors the order the keys are reported in.
        mean_sources = {
            "train/mean_reward": self._reward_buffer,
            "train/mean_correctness": self._correctness_buffer,
            "train/mean_tool_usage": self._tool_usage_buffer,
            "train/mean_efficiency": self._efficiency_buffer,
            "train/mean_diversity": self._diversity_buffer,
        }
        for metric_key, values in mean_sources.items():
            metrics[metric_key] = sum(values) / rollouts
        metrics["train/total_rollouts"] = rollouts

        # Accuracy buckets
        correct = sum(1 for score in self._correctness_buffer if score >= 0.7)
        used_tools = sum(1 for flag in self._tool_usage_buffer if flag > 0)
        metrics["train/correct_rate"] = correct / rollouts
        metrics["train/tool_usage_rate"] = used_tools / rollouts

        # Clear buffers
        for values in mean_sources.values():
            values.clear()

    await super().wandb_log(metrics)
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# Private helpers
|
# Private helpers
|
||||||
|
|
@ -380,19 +536,14 @@ class WebResearchEnv(HermesAgentBaseEnv):
|
||||||
question: str,
|
question: str,
|
||||||
expected: str,
|
expected: str,
|
||||||
model_answer: str,
|
model_answer: str,
|
||||||
ctx: Any,
|
|
||||||
) -> float:
|
) -> float:
|
||||||
"""
|
"""
|
||||||
Use an LLM to judge whether `model_answer` correctly addresses
|
Use the server's LLM to judge answer correctness.
|
||||||
`question` compared to `expected`. Returns a float in [0, 1].
|
Falls back to keyword heuristic if LLM call fails.
|
||||||
|
|
||||||
Uses the agent's own inference client if ctx is available,
|
|
||||||
otherwise falls back to a lightweight heuristic.
|
|
||||||
"""
|
"""
|
||||||
if not model_answer or not model_answer.strip():
|
if not model_answer or not model_answer.strip():
|
||||||
return 0.0
|
return 0.0
|
||||||
|
|
||||||
# Build judge prompt
|
|
||||||
judge_prompt = (
|
judge_prompt = (
|
||||||
"You are an impartial judge evaluating the quality of an AI research answer.\n\n"
|
"You are an impartial judge evaluating the quality of an AI research answer.\n\n"
|
||||||
f"Question: {question}\n\n"
|
f"Question: {question}\n\n"
|
||||||
|
|
@ -405,39 +556,36 @@ class WebResearchEnv(HermesAgentBaseEnv):
|
||||||
" 0.1 = mentions relevant topic but wrong or very incomplete\n"
|
" 0.1 = mentions relevant topic but wrong or very incomplete\n"
|
||||||
" 0.0 = completely wrong or no answer\n\n"
|
" 0.0 = completely wrong or no answer\n\n"
|
||||||
"Consider: factual accuracy, completeness, and relevance.\n"
|
"Consider: factual accuracy, completeness, and relevance.\n"
|
||||||
"Respond with ONLY a JSON object: {\"score\": <float>, \"reason\": \"<one sentence>\"}"
|
'Respond with ONLY a JSON object: {"score": <float>, "reason": "<one sentence>"}'
|
||||||
)
|
)
|
||||||
|
|
||||||
# Try using ctx for inference (Phase 2 / live training)
|
try:
|
||||||
if ctx is not None and hasattr(ctx, "chat_completion"):
|
response = await self.server.chat_completion(
|
||||||
try:
|
messages=[{"role": "user", "content": judge_prompt}],
|
||||||
response = await ctx.chat_completion(
|
n=1,
|
||||||
messages=[{"role": "user", "content": judge_prompt}],
|
max_tokens=150,
|
||||||
max_tokens=100,
|
temperature=0.0,
|
||||||
temperature=0.0,
|
split="eval",
|
||||||
)
|
)
|
||||||
text = response.get("content", "")
|
text = response.choices[0].message.content if response.choices else ""
|
||||||
parsed = self._parse_judge_json(text)
|
parsed = self._parse_judge_json(text)
|
||||||
if parsed is not None:
|
if parsed is not None:
|
||||||
return float(parsed)
|
return float(parsed)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug(f"LLM judge via ctx failed: {e}. Using heuristic.")
|
logger.debug(f"LLM judge failed: {e}. Using heuristic.")
|
||||||
|
|
||||||
# Fallback: keyword overlap heuristic
|
|
||||||
return self._heuristic_score(expected, model_answer)
|
return self._heuristic_score(expected, model_answer)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _parse_judge_json(text: str) -> Optional[float]:
|
def _parse_judge_json(text: str) -> Optional[float]:
|
||||||
"""Extract the score float from LLM judge JSON response."""
|
"""Extract the score float from LLM judge JSON response."""
|
||||||
try:
|
try:
|
||||||
# Strip markdown code fences if present
|
|
||||||
clean = re.sub(r"```(?:json)?|```", "", text).strip()
|
clean = re.sub(r"```(?:json)?|```", "", text).strip()
|
||||||
data = json.loads(clean)
|
data = json.loads(clean)
|
||||||
score = float(data.get("score", -1))
|
score = float(data.get("score", -1))
|
||||||
if 0.0 <= score <= 1.0:
|
if 0.0 <= score <= 1.0:
|
||||||
return score
|
return score
|
||||||
except Exception:
|
except Exception:
|
||||||
# Try regex fallback
|
|
||||||
match = re.search(r'"score"\s*:\s*([0-9.]+)', text)
|
match = re.search(r'"score"\s*:\s*([0-9.]+)', text)
|
||||||
if match:
|
if match:
|
||||||
score = float(match.group(1))
|
score = float(match.group(1))
|
||||||
|
|
@ -447,10 +595,7 @@ class WebResearchEnv(HermesAgentBaseEnv):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _heuristic_score(expected: str, model_answer: str) -> float:
|
def _heuristic_score(expected: str, model_answer: str) -> float:
|
||||||
"""
|
"""Lightweight keyword overlap score as fallback."""
|
||||||
Lightweight keyword overlap score as fallback when no LLM is available.
|
|
||||||
Extracts meaningful tokens and computes Jaccard similarity.
|
|
||||||
"""
|
|
||||||
stopwords = {
|
stopwords = {
|
||||||
"the", "a", "an", "is", "are", "was", "were", "of", "in", "on",
|
"the", "a", "an", "is", "are", "was", "were", "of", "in", "on",
|
||||||
"at", "to", "for", "with", "and", "or", "but", "it", "its",
|
"at", "to", "for", "with", "and", "or", "but", "it", "its",
|
||||||
|
|
@ -458,35 +603,30 @@ class WebResearchEnv(HermesAgentBaseEnv):
|
||||||
}
|
}
|
||||||
|
|
||||||
def tokenize(text: str) -> set:
|
def tokenize(text: str) -> set:
|
||||||
tokens = re.findall(r'\b[a-zA-Z0-9]+\b', text.lower())
|
tokens = re.findall(r'\b\w+\b', text.lower())
|
||||||
return {t for t in tokens if t not in stopwords and len(t) > 2}
|
return {t for t in tokens if t not in stopwords and len(t) > 2}
|
||||||
|
|
||||||
expected_tokens = tokenize(expected)
|
expected_tokens = tokenize(expected)
|
||||||
answer_tokens = tokenize(model_answer)
|
answer_tokens = tokenize(model_answer)
|
||||||
|
|
||||||
if not expected_tokens:
|
if not expected_tokens:
|
||||||
return 0.5 # Can't judge
|
return 0.5
|
||||||
|
|
||||||
overlap = len(expected_tokens & answer_tokens)
|
overlap = len(expected_tokens & answer_tokens)
|
||||||
union = len(expected_tokens | answer_tokens)
|
union = len(expected_tokens | answer_tokens)
|
||||||
|
|
||||||
jaccard = overlap / union if union > 0 else 0.0
|
jaccard = overlap / union if union > 0 else 0.0
|
||||||
# Recall-weighted: reward covering expected content
|
|
||||||
recall = overlap / len(expected_tokens)
|
recall = overlap / len(expected_tokens)
|
||||||
return min(1.0, 0.4 * jaccard + 0.6 * recall)
|
return min(1.0, 0.4 * jaccard + 0.6 * recall)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _extract_domains(text: str) -> set:
|
def _extract_domains(text: str) -> set:
|
||||||
"""
|
"""Extract unique domains from URLs cited in the response."""
|
||||||
Extract unique domains from URLs cited in the response.
|
|
||||||
Used to measure source diversity.
|
|
||||||
"""
|
|
||||||
urls = re.findall(r'https?://[^\s\)>\]"\']+', text)
|
urls = re.findall(r'https?://[^\s\)>\]"\']+', text)
|
||||||
domains = set()
|
domains = set()
|
||||||
for url in urls:
|
for url in urls:
|
||||||
try:
|
try:
|
||||||
parsed = urlparse(url)
|
parsed = urlparse(url)
|
||||||
# Normalize: strip www.
|
|
||||||
domain = parsed.netloc.lower().lstrip("www.")
|
domain = parsed.netloc.lower().lstrip("www.")
|
||||||
if domain:
|
if domain:
|
||||||
domains.add(domain)
|
domains.add(domain)
|
||||||
|
|
@ -494,20 +634,6 @@ class WebResearchEnv(HermesAgentBaseEnv):
|
||||||
pass
|
pass
|
||||||
return domains
|
return domains
|
||||||
|
|
||||||
async def _run_agent_on_item(self, item: dict) -> dict:
|
|
||||||
"""
|
|
||||||
Stub for running agent during eval. In Phase 1/2, this is handled
|
|
||||||
by the Atropos framework's rollout mechanism. Provided here for
|
|
||||||
standalone eval compatibility.
|
|
||||||
"""
|
|
||||||
# In real usage, the framework calls get_next_item + format_prompt
|
|
||||||
# and runs the agent. This stub returns an empty result for safety.
|
|
||||||
return {
|
|
||||||
"final_response": "",
|
|
||||||
"tools_used": [],
|
|
||||||
"tool_call_count": 0,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Entry point
|
# Entry point
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue