Merge: WebResearchEnv Atropos standards compliance

This commit is contained in:
teknium1 2026-03-09 17:45:57 -07:00
commit 8bc0d4f77d

View file

@ -16,21 +16,18 @@ Dataset: FRAMES benchmark (Google, 2024) — multi-hop factual questions
Usage: Usage:
# Phase 1 (OpenAI-compatible server) # Phase 1 (OpenAI-compatible server)
python environments/web_research_env.py serve \ python environments/web_research_env.py serve \\
--openai.base_url http://localhost:8000/v1 \ --openai.base_url http://localhost:8000/v1 \\
--openai.model_name YourModel \ --openai.model_name YourModel \\
--openai.server_type openai --openai.server_type openai
# With eval split # Process mode (offline data generation)
python environments/web_research_env.py serve \ python environments/web_research_env.py process \\
--openai.base_url http://localhost:8000/v1 \ --env.data_path_to_save_groups data/web_research.jsonl
--openai.model_name YourModel \
--env.eval_every 50 \
--env.eval_size 20
# Standalone eval (no training server needed) # Standalone eval
python environments/web_research_env.py eval \ python environments/web_research_env.py evaluate \\
--openai.base_url http://localhost:8000/v1 \ --openai.base_url http://localhost:8000/v1 \\
--openai.model_name YourModel --openai.model_name YourModel
Built by: github.com/jackx707 Built by: github.com/jackx707
@ -43,11 +40,21 @@ from __future__ import annotations
import asyncio import asyncio
import json import json
import logging import logging
import os
import random import random
import re import re
from typing import Any, Optional import sys
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
from urllib.parse import urlparse from urllib.parse import urlparse
from pydantic import Field
# Ensure hermes-agent root is on path
_repo_root = Path(__file__).resolve().parent.parent
if str(_repo_root) not in sys.path:
sys.path.insert(0, str(_repo_root))
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Optional HuggingFace datasets import # Optional HuggingFace datasets import
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@ -57,13 +64,19 @@ try:
except ImportError: except ImportError:
HF_AVAILABLE = False HF_AVAILABLE = False
from environments.hermes_base_env import HermesAgentBaseEnv from atroposlib.envs.base import ScoredDataGroup
from atroposlib.envs.server_handling.server_manager import APIServerConfig
from atroposlib.type_definitions import Item
from environments.hermes_base_env import HermesAgentBaseEnv, HermesAgentEnvConfig
from environments.agent_loop import AgentResult
from environments.tool_context import ToolContext
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Fallback sample dataset (used when HuggingFace is unavailable) # Fallback sample dataset (used when HuggingFace is unavailable)
# These are multi-hop questions that require real web search to answer. # Multi-hop questions requiring real web search to answer.
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
SAMPLE_QUESTIONS = [ SAMPLE_QUESTIONS = [
{ {
@ -129,6 +142,58 @@ SAMPLE_QUESTIONS = [
] ]
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
class WebResearchEnvConfig(HermesAgentEnvConfig):
    """Configuration for the web research RL environment.

    Extends the Hermes agent base config with reward-shaping weights,
    efficiency thresholds, eval-split sizing, and the dataset source.
    The three main weights (correctness, tool usage, efficiency) sum to
    1.0 by default; the diversity bonus is additive on top.
    """

    # --- Reward weights -------------------------------------------------
    correctness_weight: float = Field(
        default=0.6,
        description="Weight for answer correctness in reward (LLM judge score).",
    )
    tool_usage_weight: float = Field(
        default=0.2,
        description="Weight for tool usage signal (did the model actually use web tools?).",
    )
    efficiency_weight: float = Field(
        default=0.2,
        description="Weight for efficiency signal (penalizes excessive tool calls).",
    )
    diversity_bonus: float = Field(
        default=0.1,
        description="Bonus reward for citing ≥2 distinct domains.",
    )

    # --- Efficiency thresholds ------------------------------------------
    efficient_max_calls: int = Field(
        default=5,
        description="Maximum tool calls before efficiency penalty begins.",
    )
    heavy_penalty_calls: int = Field(
        default=10,
        description="Tool call count where efficiency penalty steepens.",
    )

    # --- Eval -----------------------------------------------------------
    eval_size: int = Field(
        default=20,
        description="Number of held-out items for evaluation.",
    )
    eval_split_ratio: float = Field(
        default=0.1,
        # Fixed garbled range text ("0.01.0" -> "0.0-1.0").
        description="Fraction of dataset to hold out for evaluation (0.0-1.0).",
    )

    # --- Dataset --------------------------------------------------------
    dataset_name: str = Field(
        default="google/frames-benchmark",
        description="HuggingFace dataset name for research questions.",
    )
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Environment # Environment
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@ -143,23 +208,60 @@ class WebResearchEnv(HermesAgentBaseEnv):
Reward is multi-signal: Reward is multi-signal:
60% answer correctness (LLM judge) 60% answer correctness (LLM judge)
20% tool usage (did the model actually search the web?) 20% tool usage (did the model actually search the web?)
20% efficiency (penalizes >6 tool calls) 20% efficiency (penalizes >5 tool calls)
Bonus +0.1 for source diversity (2 distinct domains cited). Bonus +0.1 for source diversity (2 distinct domains cited).
""" """
name = "web-research" name = "web-research"
env_config_cls = WebResearchEnvConfig
# Default toolsets for this environment — web + file for saving notes # Default toolsets for this environment — web + file for saving notes
default_toolsets = ["web", "file"] default_toolsets = ["web", "file"]
@classmethod
def config_init(cls) -> Tuple[WebResearchEnvConfig, List[APIServerConfig]]:
"""Default configuration for the web research environment."""
env_config = WebResearchEnvConfig(
enabled_toolsets=["web", "file"],
max_agent_turns=15,
agent_temperature=1.0,
system_prompt=(
"You are a highly capable research agent. When asked a factual question, "
"always use web_search to find current, accurate information before answering. "
"Cite at least 2 sources. Be concise and accurate."
),
group_size=4,
total_steps=1000,
steps_per_eval=100,
use_wandb=True,
wandb_name="web-research",
)
server_configs = [
APIServerConfig(
base_url="https://openrouter.ai/api/v1",
model_name="anthropic/claude-sonnet-4.5",
server_type="openai",
api_key=os.getenv("OPENROUTER_API_KEY", ""),
health_check=False,
)
]
return env_config, server_configs
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self._items: list[dict] = [] self._items: list[dict] = []
self._eval_items: list[dict] = [] self._eval_items: list[dict] = []
self._index: int = 0 self._index: int = 0
self._total_scored: int = 0
self._total_reward: float = 0.0 # Metrics tracking for wandb
self._reward_buffer: list[float] = []
self._correctness_buffer: list[float] = []
self._tool_usage_buffer: list[float] = []
self._efficiency_buffer: list[float] = []
self._diversity_buffer: list[float] = []
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# 1. Setup — load dataset # 1. Setup — load dataset
@ -170,7 +272,7 @@ class WebResearchEnv(HermesAgentBaseEnv):
if HF_AVAILABLE: if HF_AVAILABLE:
try: try:
logger.info("Loading FRAMES benchmark from HuggingFace...") logger.info("Loading FRAMES benchmark from HuggingFace...")
ds = load_dataset("google/frames-benchmark", split="test") ds = load_dataset(self.config.dataset_name, split="test")
self._items = [ self._items = [
{ {
"question": row["Prompt"], "question": row["Prompt"],
@ -180,8 +282,11 @@ class WebResearchEnv(HermesAgentBaseEnv):
} }
for row in ds for row in ds
] ]
# Hold out 10% for eval # Hold out for eval
eval_size = max(20, len(self._items) // 10) eval_size = max(
self.config.eval_size,
int(len(self._items) * self.config.eval_split_ratio),
)
random.shuffle(self._items) random.shuffle(self._items)
self._eval_items = self._items[:eval_size] self._eval_items = self._items[:eval_size]
self._items = self._items[eval_size:] self._items = self._items[eval_size:]
@ -220,10 +325,7 @@ class WebResearchEnv(HermesAgentBaseEnv):
# ------------------------------------------------------------------ # ------------------------------------------------------------------
def format_prompt(self, item: dict) -> str: def format_prompt(self, item: dict) -> str:
""" """Format the research question as a task prompt."""
Format the research question as a task prompt.
Instructs the model to use web search and cite sources.
"""
return ( return (
f"Research the following question thoroughly using web search. " f"Research the following question thoroughly using web search. "
f"You MUST search the web to find current, accurate information — " f"You MUST search the web to find current, accurate information — "
@ -243,27 +345,30 @@ class WebResearchEnv(HermesAgentBaseEnv):
async def compute_reward( async def compute_reward(
self, self,
item: dict, item: dict,
result: dict, result: AgentResult,
ctx: Any, # ToolContext ctx: ToolContext,
) -> float: ) -> float:
""" """
Multi-signal reward function: Multi-signal reward function:
0.6 * correctness LLM judge comparing answer to ground truth correctness_weight * correctness LLM judge comparing answer to ground truth
0.2 * tool_used binary: did the model use web tools? tool_usage_weight * tool_used binary: did the model use web tools?
0.2 * efficiency penalizes wasteful tool usage efficiency_weight * efficiency penalizes wasteful tool usage
+0.1 bonus source diversity (2 distinct domains) + diversity_bonus source diversity (2 distinct domains)
""" """
final_response: str = result.get("final_response", "") final_response: str = result.final_response or ""
tools_used: list[str] = result.get("tools_used", []) tools_used: list[str] = [
tool_call_count: int = result.get("tool_call_count", len(tools_used)) tc.tool_name for tc in (result.tool_calls or [])
] if hasattr(result, "tool_calls") and result.tool_calls else []
tool_call_count: int = result.turns_used or len(tools_used)
cfg = self.config
# ---- Signal 1: Answer correctness (LLM judge) ---------------- # ---- Signal 1: Answer correctness (LLM judge) ----------------
correctness = await self._llm_judge( correctness = await self._llm_judge(
question=item["question"], question=item["question"],
expected=item["answer"], expected=item["answer"],
model_answer=final_response, model_answer=final_response,
ctx=ctx,
) )
# ---- Signal 2: Web tool usage -------------------------------- # ---- Signal 2: Web tool usage --------------------------------
@ -271,35 +376,37 @@ class WebResearchEnv(HermesAgentBaseEnv):
tool_used = 1.0 if any(t in web_tools for t in tools_used) else 0.0 tool_used = 1.0 if any(t in web_tools for t in tools_used) else 0.0
# ---- Signal 3: Efficiency ------------------------------------ # ---- Signal 3: Efficiency ------------------------------------
# Ideal: 2-5 tool calls. Penalise beyond 6, hard cap at 15. if tool_call_count <= cfg.efficient_max_calls:
if tool_call_count <= 5:
efficiency = 1.0 efficiency = 1.0
elif tool_call_count <= 10: elif tool_call_count <= cfg.heavy_penalty_calls:
efficiency = 1.0 - (tool_call_count - 5) * 0.08 efficiency = 1.0 - (tool_call_count - cfg.efficient_max_calls) * 0.08
else: else:
efficiency = max(0.0, 1.0 - (tool_call_count - 5) * 0.12) efficiency = max(0.0, 1.0 - (tool_call_count - cfg.efficient_max_calls) * 0.12)
# ---- Bonus: Source diversity --------------------------------- # ---- Bonus: Source diversity ---------------------------------
domains = self._extract_domains(final_response) domains = self._extract_domains(final_response)
diversity_bonus = 0.1 if len(domains) >= 2 else 0.0 diversity = cfg.diversity_bonus if len(domains) >= 2 else 0.0
# ---- Combine ------------------------------------------------ # ---- Combine ------------------------------------------------
reward = ( reward = (
0.6 * correctness cfg.correctness_weight * correctness
+ 0.2 * tool_used + cfg.tool_usage_weight * tool_used
+ 0.2 * efficiency + cfg.efficiency_weight * efficiency
+ diversity_bonus + diversity
) )
reward = min(1.0, max(0.0, reward)) # clamp to [0, 1] reward = min(1.0, max(0.0, reward)) # clamp to [0, 1]
# Track running stats # Track for wandb
self._total_scored += 1 self._reward_buffer.append(reward)
self._total_reward += reward self._correctness_buffer.append(correctness)
self._tool_usage_buffer.append(tool_used)
self._efficiency_buffer.append(efficiency)
self._diversity_buffer.append(diversity)
logger.debug( logger.debug(
f"Reward breakdown — correctness={correctness:.2f}, " f"Reward breakdown — correctness={correctness:.2f}, "
f"tool_used={tool_used:.1f}, efficiency={efficiency:.2f}, " f"tool_used={tool_used:.1f}, efficiency={efficiency:.2f}, "
f"diversity_bonus={diversity_bonus:.1f} → total={reward:.3f}" f"diversity={diversity:.1f} → total={reward:.3f}"
) )
return reward return reward
@ -308,68 +415,117 @@ class WebResearchEnv(HermesAgentBaseEnv):
# 5. evaluate — run on held-out eval split # 5. evaluate — run on held-out eval split
# ------------------------------------------------------------------ # ------------------------------------------------------------------
async def evaluate( async def evaluate(self, *args, **kwargs) -> None:
self, """Run evaluation on the held-out split using the agent loop."""
*args: Any, import time
eval_size: Optional[int] = None,
**kwargs: Any,
) -> dict:
"""
Run evaluation on the held-out split.
Returns a dict of metrics for logging.
"""
items = self._eval_items
if eval_size:
items = items[:eval_size]
items = self._eval_items
if not items: if not items:
logger.warning("No eval items available.") logger.warning("No eval items available.")
return {} return
logger.info(f"Running eval on {len(items)} questions...") eval_size = min(self.config.eval_size, len(items))
eval_items = items[:eval_size]
rewards = [] logger.info(f"Running eval on {len(eval_items)} questions...")
correctness_scores = [] start_time = time.time()
samples = []
for item in items: for item in eval_items:
try: try:
# Run the agent on each eval question # Use the base env's agent loop for eval (same as training)
result = await self._run_agent_on_item(item) prompt = self.format_prompt(item)
reward = await self.compute_reward(item, result, ctx=None) completion = await self.server.chat_completion(
rewards.append(reward) messages=[
{"role": "system", "content": self.config.system_prompt or ""},
{"role": "user", "content": prompt},
],
n=1,
max_tokens=self.config.max_token_length,
temperature=0.0,
split="eval",
)
response_content = (
completion.choices[0].message.content if completion.choices else ""
)
# Score the response
correctness = await self._llm_judge(
question=item["question"],
expected=item["answer"],
model_answer=response_content,
)
samples.append({
"prompt": item["question"],
"response": response_content,
"expected": item["answer"],
"correctness": correctness,
})
# Also track raw correctness separately
if result.get("final_response"):
correctness_scores.append(
await self._llm_judge(
question=item["question"],
expected=item["answer"],
model_answer=result["final_response"],
ctx=None,
)
)
except Exception as e: except Exception as e:
logger.error(f"Eval error on item: {e}") logger.error(f"Eval error on item: {e}")
rewards.append(0.0) samples.append({
"prompt": item["question"],
"response": f"ERROR: {e}",
"expected": item["answer"],
"correctness": 0.0,
})
metrics = { end_time = time.time()
"eval/mean_reward": sum(rewards) / len(rewards) if rewards else 0.0,
# Compute metrics
correctness_scores = [s["correctness"] for s in samples]
eval_metrics = {
"eval/mean_correctness": ( "eval/mean_correctness": (
sum(correctness_scores) / len(correctness_scores) sum(correctness_scores) / len(correctness_scores)
if correctness_scores else 0.0 if correctness_scores else 0.0
), ),
"eval/n_items": len(rewards), "eval/n_items": len(samples),
"train/mean_reward_so_far": (
self._total_reward / self._total_scored
if self._total_scored > 0 else 0.0
),
} }
logger.info( await self.evaluate_log(
f"Eval complete — mean_reward={metrics['eval/mean_reward']:.3f}, " metrics=eval_metrics,
f"mean_correctness={metrics['eval/mean_correctness']:.3f}" samples=samples,
start_time=start_time,
end_time=end_time,
) )
return metrics
# ------------------------------------------------------------------
# 6. wandb_log — custom metrics
# ------------------------------------------------------------------
async def wandb_log(self, wandb_metrics: Optional[Dict] = None) -> None:
"""Log reward breakdown metrics to wandb."""
if wandb_metrics is None:
wandb_metrics = {}
if self._reward_buffer:
n = len(self._reward_buffer)
wandb_metrics["train/mean_reward"] = sum(self._reward_buffer) / n
wandb_metrics["train/mean_correctness"] = sum(self._correctness_buffer) / n
wandb_metrics["train/mean_tool_usage"] = sum(self._tool_usage_buffer) / n
wandb_metrics["train/mean_efficiency"] = sum(self._efficiency_buffer) / n
wandb_metrics["train/mean_diversity"] = sum(self._diversity_buffer) / n
wandb_metrics["train/total_rollouts"] = n
# Accuracy buckets
wandb_metrics["train/correct_rate"] = (
sum(1 for c in self._correctness_buffer if c >= 0.7) / n
)
wandb_metrics["train/tool_usage_rate"] = (
sum(1 for t in self._tool_usage_buffer if t > 0) / n
)
# Clear buffers
self._reward_buffer.clear()
self._correctness_buffer.clear()
self._tool_usage_buffer.clear()
self._efficiency_buffer.clear()
self._diversity_buffer.clear()
await super().wandb_log(wandb_metrics)
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# Private helpers # Private helpers
@ -380,19 +536,14 @@ class WebResearchEnv(HermesAgentBaseEnv):
question: str, question: str,
expected: str, expected: str,
model_answer: str, model_answer: str,
ctx: Any,
) -> float: ) -> float:
""" """
Use an LLM to judge whether `model_answer` correctly addresses Use the server's LLM to judge answer correctness.
`question` compared to `expected`. Returns a float in [0, 1]. Falls back to keyword heuristic if LLM call fails.
Uses the agent's own inference client if ctx is available,
otherwise falls back to a lightweight heuristic.
""" """
if not model_answer or not model_answer.strip(): if not model_answer or not model_answer.strip():
return 0.0 return 0.0
# Build judge prompt
judge_prompt = ( judge_prompt = (
"You are an impartial judge evaluating the quality of an AI research answer.\n\n" "You are an impartial judge evaluating the quality of an AI research answer.\n\n"
f"Question: {question}\n\n" f"Question: {question}\n\n"
@ -405,39 +556,36 @@ class WebResearchEnv(HermesAgentBaseEnv):
" 0.1 = mentions relevant topic but wrong or very incomplete\n" " 0.1 = mentions relevant topic but wrong or very incomplete\n"
" 0.0 = completely wrong or no answer\n\n" " 0.0 = completely wrong or no answer\n\n"
"Consider: factual accuracy, completeness, and relevance.\n" "Consider: factual accuracy, completeness, and relevance.\n"
"Respond with ONLY a JSON object: {\"score\": <float>, \"reason\": \"<one sentence>\"}" 'Respond with ONLY a JSON object: {"score": <float>, "reason": "<one sentence>"}'
) )
# Try using ctx for inference (Phase 2 / live training) try:
if ctx is not None and hasattr(ctx, "chat_completion"): response = await self.server.chat_completion(
try: messages=[{"role": "user", "content": judge_prompt}],
response = await ctx.chat_completion( n=1,
messages=[{"role": "user", "content": judge_prompt}], max_tokens=150,
max_tokens=100, temperature=0.0,
temperature=0.0, split="eval",
) )
text = response.get("content", "") text = response.choices[0].message.content if response.choices else ""
parsed = self._parse_judge_json(text) parsed = self._parse_judge_json(text)
if parsed is not None: if parsed is not None:
return float(parsed) return float(parsed)
except Exception as e: except Exception as e:
logger.debug(f"LLM judge via ctx failed: {e}. Using heuristic.") logger.debug(f"LLM judge failed: {e}. Using heuristic.")
# Fallback: keyword overlap heuristic
return self._heuristic_score(expected, model_answer) return self._heuristic_score(expected, model_answer)
@staticmethod @staticmethod
def _parse_judge_json(text: str) -> Optional[float]: def _parse_judge_json(text: str) -> Optional[float]:
"""Extract the score float from LLM judge JSON response.""" """Extract the score float from LLM judge JSON response."""
try: try:
# Strip markdown code fences if present
clean = re.sub(r"```(?:json)?|```", "", text).strip() clean = re.sub(r"```(?:json)?|```", "", text).strip()
data = json.loads(clean) data = json.loads(clean)
score = float(data.get("score", -1)) score = float(data.get("score", -1))
if 0.0 <= score <= 1.0: if 0.0 <= score <= 1.0:
return score return score
except Exception: except Exception:
# Try regex fallback
match = re.search(r'"score"\s*:\s*([0-9.]+)', text) match = re.search(r'"score"\s*:\s*([0-9.]+)', text)
if match: if match:
score = float(match.group(1)) score = float(match.group(1))
@ -447,10 +595,7 @@ class WebResearchEnv(HermesAgentBaseEnv):
@staticmethod @staticmethod
def _heuristic_score(expected: str, model_answer: str) -> float: def _heuristic_score(expected: str, model_answer: str) -> float:
""" """Lightweight keyword overlap score as fallback."""
Lightweight keyword overlap score as fallback when no LLM is available.
Extracts meaningful tokens and computes Jaccard similarity.
"""
stopwords = { stopwords = {
"the", "a", "an", "is", "are", "was", "were", "of", "in", "on", "the", "a", "an", "is", "are", "was", "were", "of", "in", "on",
"at", "to", "for", "with", "and", "or", "but", "it", "its", "at", "to", "for", "with", "and", "or", "but", "it", "its",
@ -458,35 +603,30 @@ class WebResearchEnv(HermesAgentBaseEnv):
} }
def tokenize(text: str) -> set: def tokenize(text: str) -> set:
tokens = re.findall(r'\b[a-zA-Z0-9]+\b', text.lower()) tokens = re.findall(r'\b\w+\b', text.lower())
return {t for t in tokens if t not in stopwords and len(t) > 2} return {t for t in tokens if t not in stopwords and len(t) > 2}
expected_tokens = tokenize(expected) expected_tokens = tokenize(expected)
answer_tokens = tokenize(model_answer) answer_tokens = tokenize(model_answer)
if not expected_tokens: if not expected_tokens:
return 0.5 # Can't judge return 0.5
overlap = len(expected_tokens & answer_tokens) overlap = len(expected_tokens & answer_tokens)
union = len(expected_tokens | answer_tokens) union = len(expected_tokens | answer_tokens)
jaccard = overlap / union if union > 0 else 0.0 jaccard = overlap / union if union > 0 else 0.0
# Recall-weighted: reward covering expected content
recall = overlap / len(expected_tokens) recall = overlap / len(expected_tokens)
return min(1.0, 0.4 * jaccard + 0.6 * recall) return min(1.0, 0.4 * jaccard + 0.6 * recall)
@staticmethod @staticmethod
def _extract_domains(text: str) -> set: def _extract_domains(text: str) -> set:
""" """Extract unique domains from URLs cited in the response."""
Extract unique domains from URLs cited in the response.
Used to measure source diversity.
"""
urls = re.findall(r'https?://[^\s\)>\]"\']+', text) urls = re.findall(r'https?://[^\s\)>\]"\']+', text)
domains = set() domains = set()
for url in urls: for url in urls:
try: try:
parsed = urlparse(url) parsed = urlparse(url)
# Normalize: strip www.
domain = parsed.netloc.lower().lstrip("www.") domain = parsed.netloc.lower().lstrip("www.")
if domain: if domain:
domains.add(domain) domains.add(domain)
@ -494,20 +634,6 @@ class WebResearchEnv(HermesAgentBaseEnv):
pass pass
return domains return domains
async def _run_agent_on_item(self, item: dict) -> dict:
"""
Stub for running agent during eval. In Phase 1/2, this is handled
by the Atropos framework's rollout mechanism. Provided here for
standalone eval compatibility.
"""
# In real usage, the framework calls get_next_item + format_prompt
# and runs the agent. This stub returns an empty result for safety.
return {
"final_response": "",
"tools_used": [],
"tool_call_count": 0,
}
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Entry point # Entry point