Merge: fix double judge call + eval buffer pollution in WebResearchEnv
Commit d5811c887a
1 changed file with 26 additions and 13 deletions
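Previously the rollout path called compute_reward (which runs the LLM judge) and then called self._llm_judge a second time just to log a separate correctness metric, so every eval sample paid for two judge calls; compute_reward also appended its scores to the training metric buffers, polluting them with eval data. The fix calls compute_reward once, after the final response has been extracted, snapshots the buffer length first so the correctness score it appends can be read back for the metric, and then pops the eval entries off the reward, correctness, tool-usage, efficiency, and diversity buffers.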
@@ -475,14 +475,7 @@ class WebResearchEnv(HermesAgentBaseEnv):
         )
         result = await agent.run(messages)
 
-        # Extract final response and compute reward
-        ctx = ToolContext(task_id)
-        try:
-            reward = await self.compute_reward(item, result, ctx)
-        finally:
-            ctx.cleanup()
-
-        # Extract final response for logging
+        # Extract final response and tool usage from messages
         final_response = ""
         tool_call_count = 0
         for msg in reversed(result.messages):
@@ -491,12 +484,32 @@ class WebResearchEnv(HermesAgentBaseEnv):
             if msg.get("role") == "assistant" and msg.get("tool_calls"):
                 tool_call_count += len(msg["tool_calls"])
 
-        # Score correctness separately for the metric
-        correctness = await self._llm_judge(
-            question=item["question"],
-            expected=item["answer"],
-            model_answer=final_response,
+        # Compute reward (includes LLM judge for correctness)
+        # Temporarily save buffer lengths so we can extract the
+        # correctness score without calling judge twice, and avoid
+        # polluting training metric buffers with eval data.
+        buf_len = len(self._correctness_buffer)
+        ctx = ToolContext(task_id)
+        try:
+            reward = await self.compute_reward(item, result, ctx)
+        finally:
+            ctx.cleanup()
+
+        # Extract correctness from the buffer (compute_reward appended it)
+        # then remove eval entries from training buffers
+        correctness = (
+            self._correctness_buffer[buf_len]
+            if len(self._correctness_buffer) > buf_len
+            else 0.0
         )
+        # Roll back buffers to avoid polluting training metrics
+        for buf in (
+            self._reward_buffer, self._correctness_buffer,
+            self._tool_usage_buffer, self._efficiency_buffer,
+            self._diversity_buffer,
+        ):
+            if len(buf) > buf_len:
+                buf.pop()
 
         samples.append({
             "prompt": item["question"],
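For reference, a minimal sketch of the snapshot-and-rollback pattern the new code uses, written against a hypothetical FakeEnv rather than the real WebResearchEnv. It assumes the metric buffers are plain lists and that compute_reward appends exactly one entry to each buffer per call; FakeEnv, eval_step, and the stand-in reward formula are illustrative, not the project's API.

# Hypothetical stand-in for the real environment; buffer names mirror the
# diff, but compute_reward here only mimics the side effect of appending
# one score per buffer.
class FakeEnv:
    def __init__(self):
        self._reward_buffer = []
        self._correctness_buffer = []

    def compute_reward(self, correctness):
        # Pretend the LLM judge produced `correctness`; append to the
        # metric buffers as a side effect, like the real compute_reward.
        self._correctness_buffer.append(correctness)
        reward = 0.5 * correctness
        self._reward_buffer.append(reward)
        return reward

    def eval_step(self, judged_correctness):
        # Snapshot the buffer length before the single reward call...
        buf_len = len(self._correctness_buffer)
        reward = self.compute_reward(judged_correctness)
        # ...read back the score that call appended (no second judge call)...
        correctness = (
            self._correctness_buffer[buf_len]
            if len(self._correctness_buffer) > buf_len
            else 0.0
        )
        # ...then pop the eval entries so training metrics stay clean.
        for buf in (self._reward_buffer, self._correctness_buffer):
            if len(buf) > buf_len:
                buf.pop()
        return reward, correctness

env = FakeEnv()
env.compute_reward(1.0)              # training sample: stays in the buffers
reward, correctness = env.eval_step(0.8)
assert correctness == 0.8
assert len(env._correctness_buffer) == 1  # eval entry was rolled back

The one-pop-per-buffer rollback works because a single eval sample triggers exactly one compute_reward call, and hence at most one appended entry per buffer.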