refactor: streamline scratchpad handling in AIAgent
- Removed the static methods for converting and checking <REASONING_SCRATCHPAD> tags, simplifying the codebase.
- Replaced calls to the removed methods with direct calls to standalone functions for better clarity and maintainability.
- Updated the trajectory-saving logic to delegate to a dedicated function for improved organization and readability.
This commit is contained in:
parent
8fedbf87d9
commit
d18c753b3c
1 changed file with 4 additions and 60 deletions
64
run_agent.py
64
run_agent.py
|
|
@ -637,43 +637,6 @@ class AIAgent:
|
||||||
|
|
||||||
return json.dumps(formatted_tools, ensure_ascii=False)
|
return json.dumps(formatted_tools, ensure_ascii=False)
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _convert_scratchpad_to_think(content: str) -> str:
|
|
||||||
"""
|
|
||||||
Convert <REASONING_SCRATCHPAD> tags to <think> tags in content.
|
|
||||||
|
|
||||||
When native thinking/reasoning is disabled and the model is prompted to
|
|
||||||
reason inside <REASONING_SCRATCHPAD> XML tags instead, this converts those
|
|
||||||
to the standard <think> format used in our trajectory storage.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
content: Assistant message content that may contain scratchpad tags
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Content with scratchpad tags replaced by think tags
|
|
||||||
"""
|
|
||||||
if not content or "<REASONING_SCRATCHPAD>" not in content:
|
|
||||||
return content
|
|
||||||
return content.replace("<REASONING_SCRATCHPAD>", "<think>").replace("</REASONING_SCRATCHPAD>", "</think>")
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _has_incomplete_scratchpad(content: str) -> bool:
|
|
||||||
"""
|
|
||||||
Check if content has an opening <REASONING_SCRATCHPAD> without a closing tag.
|
|
||||||
|
|
||||||
This indicates the model ran out of output tokens mid-reasoning, producing
|
|
||||||
a broken turn that shouldn't be saved. The caller should retry or discard.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
content: Assistant message content to check
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if there's an unclosed scratchpad tag
|
|
||||||
"""
|
|
||||||
if not content:
|
|
||||||
return False
|
|
||||||
return "<REASONING_SCRATCHPAD>" in content and "</REASONING_SCRATCHPAD>" not in content
|
|
||||||
|
|
||||||
def _convert_to_trajectory_format(self, messages: List[Dict[str, Any]], user_query: str, completed: bool) -> List[Dict[str, Any]]:
|
def _convert_to_trajectory_format(self, messages: List[Dict[str, Any]], user_query: str, completed: bool) -> List[Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
Convert internal message format to trajectory format for saving.
|
Convert internal message format to trajectory format for saving.
|
||||||
|
|
@ -738,7 +701,7 @@ class AIAgent:
|
||||||
if msg.get("content") and msg["content"].strip():
|
if msg.get("content") and msg["content"].strip():
|
||||||
# Convert any <REASONING_SCRATCHPAD> tags to <think> tags
|
# Convert any <REASONING_SCRATCHPAD> tags to <think> tags
|
||||||
# (used when native thinking is disabled and model reasons via XML)
|
# (used when native thinking is disabled and model reasons via XML)
|
||||||
content += self._convert_scratchpad_to_think(msg["content"]) + "\n"
|
content += convert_scratchpad_to_think(msg["content"]) + "\n"
|
||||||
|
|
||||||
# Add tool calls wrapped in XML tags
|
# Add tool calls wrapped in XML tags
|
||||||
for tool_call in msg["tool_calls"]:
|
for tool_call in msg["tool_calls"]:
|
||||||
|
|
@ -813,7 +776,7 @@ class AIAgent:
|
||||||
# Convert any <REASONING_SCRATCHPAD> tags to <think> tags
|
# Convert any <REASONING_SCRATCHPAD> tags to <think> tags
|
||||||
# (used when native thinking is disabled and model reasons via XML)
|
# (used when native thinking is disabled and model reasons via XML)
|
||||||
raw_content = msg["content"] or ""
|
raw_content = msg["content"] or ""
|
||||||
content += self._convert_scratchpad_to_think(raw_content)
|
content += convert_scratchpad_to_think(raw_content)
|
||||||
|
|
||||||
# Ensure every gpt turn has a <think> block (empty if no reasoning)
|
# Ensure every gpt turn has a <think> block (empty if no reasoning)
|
||||||
if "<think>" not in content:
|
if "<think>" not in content:
|
||||||
|
|
@ -846,27 +809,8 @@ class AIAgent:
|
||||||
if not self.save_trajectories:
|
if not self.save_trajectories:
|
||||||
return
|
return
|
||||||
|
|
||||||
# Convert messages to trajectory format
|
|
||||||
trajectory = self._convert_to_trajectory_format(messages, user_query, completed)
|
trajectory = self._convert_to_trajectory_format(messages, user_query, completed)
|
||||||
|
_save_trajectory_to_file(trajectory, self.model, completed)
|
||||||
# Determine which file to save to
|
|
||||||
filename = "trajectory_samples.jsonl" if completed else "failed_trajectories.jsonl"
|
|
||||||
|
|
||||||
# Create trajectory entry
|
|
||||||
entry = {
|
|
||||||
"conversations": trajectory,
|
|
||||||
"timestamp": datetime.now().isoformat(),
|
|
||||||
"model": self.model,
|
|
||||||
"completed": completed
|
|
||||||
}
|
|
||||||
|
|
||||||
# Append to JSONL file
|
|
||||||
try:
|
|
||||||
with open(filename, "a", encoding="utf-8") as f:
|
|
||||||
f.write(json.dumps(entry, ensure_ascii=False) + "\n")
|
|
||||||
logger.info("Trajectory saved to %s", filename)
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning("Failed to save trajectory: %s", e)
|
|
||||||
|
|
||||||
def _mask_api_key_for_logs(self, key: Optional[str]) -> Optional[str]:
|
def _mask_api_key_for_logs(self, key: Optional[str]) -> Optional[str]:
|
||||||
if not key:
|
if not key:
|
||||||
|
|
@ -2134,7 +2078,7 @@ class AIAgent:
|
||||||
|
|
||||||
# Check for incomplete <REASONING_SCRATCHPAD> (opened but never closed)
|
# Check for incomplete <REASONING_SCRATCHPAD> (opened but never closed)
|
||||||
# This means the model ran out of output tokens mid-reasoning — retry up to 2 times
|
# This means the model ran out of output tokens mid-reasoning — retry up to 2 times
|
||||||
if self._has_incomplete_scratchpad(assistant_message.content or ""):
|
if has_incomplete_scratchpad(assistant_message.content or ""):
|
||||||
if not hasattr(self, '_incomplete_scratchpad_retries'):
|
if not hasattr(self, '_incomplete_scratchpad_retries'):
|
||||||
self._incomplete_scratchpad_retries = 0
|
self._incomplete_scratchpad_retries = 0
|
||||||
self._incomplete_scratchpad_retries += 1
|
self._incomplete_scratchpad_retries += 1
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue