Implement enhanced response handling and tool call validation in run_agent
- Added methods to check for meaningful content after <think> blocks and to retrieve messages up to the last complete assistant turn.
- Introduced retry logic for handling truncated responses and invalid JSON arguments in tool calls, with a maximum retry limit.
- Improved logging for invalid JSON and empty responses, ensuring better error tracking and handling.
- Updated the batch data generation script to adjust the dataset file, batch size, and ephemeral system prompt for improved context management.
This commit is contained in:
parent
4071ba29da
commit
66daebe88f
2 changed files with 194 additions and 7 deletions
191
run_agent.py
191
run_agent.py
|
|
@ -208,6 +208,60 @@ class AIAgent:
|
|||
prompt_preview = self.ephemeral_system_prompt[:60] + "..." if len(self.ephemeral_system_prompt) > 60 else self.ephemeral_system_prompt
|
||||
print(f"🔒 Ephemeral system prompt: '{prompt_preview}' (not saved to trajectories)")
|
||||
|
||||
def _has_content_after_think_block(self, content: str) -> bool:
|
||||
"""
|
||||
Check if content has actual text after any <think></think> blocks.
|
||||
|
||||
This detects cases where the model only outputs reasoning but no actual
|
||||
response, which indicates an incomplete generation that should be retried.
|
||||
|
||||
Args:
|
||||
content: The assistant message content to check
|
||||
|
||||
Returns:
|
||||
True if there's meaningful content after think blocks, False otherwise
|
||||
"""
|
||||
if not content:
|
||||
return False
|
||||
|
||||
import re
|
||||
# Remove all <think>...</think> blocks (including nested ones, non-greedy)
|
||||
cleaned = re.sub(r'<think>.*?</think>', '', content, flags=re.DOTALL)
|
||||
|
||||
# Check if there's any non-whitespace content remaining
|
||||
return bool(cleaned.strip())
|
||||
|
||||
def _get_messages_up_to_last_assistant(self, messages: List[Dict]) -> List[Dict]:
|
||||
"""
|
||||
Get messages up to (but not including) the last assistant turn.
|
||||
|
||||
This is used when we need to "roll back" to the last successful point
|
||||
in the conversation, typically when the final assistant message is
|
||||
incomplete or malformed.
|
||||
|
||||
Args:
|
||||
messages: Full message list
|
||||
|
||||
Returns:
|
||||
Messages up to the last complete assistant turn (ending with user/tool message)
|
||||
"""
|
||||
if not messages:
|
||||
return []
|
||||
|
||||
# Find the index of the last assistant message
|
||||
last_assistant_idx = None
|
||||
for i in range(len(messages) - 1, -1, -1):
|
||||
if messages[i].get("role") == "assistant":
|
||||
last_assistant_idx = i
|
||||
break
|
||||
|
||||
if last_assistant_idx is None:
|
||||
# No assistant message found, return all messages
|
||||
return messages.copy()
|
||||
|
||||
# Return everything up to (not including) the last assistant message
|
||||
return messages[:last_assistant_idx]
|
||||
|
||||
def _format_tools_for_system_message(self) -> str:
|
||||
"""
|
||||
Format tool definitions for the system message in the trajectory format.
|
||||
|
|
@ -292,9 +346,19 @@ class AIAgent:
|
|||
|
||||
# Add tool calls wrapped in XML tags
|
||||
for tool_call in msg["tool_calls"]:
|
||||
# Parse arguments - should always succeed since we validate during conversation
|
||||
# but keep try-except as safety net
|
||||
try:
|
||||
arguments = json.loads(tool_call["function"]["arguments"]) if isinstance(tool_call["function"]["arguments"], str) else tool_call["function"]["arguments"]
|
||||
except json.JSONDecodeError:
|
||||
# This shouldn't happen since we validate and retry during conversation,
|
||||
# but if it does, log warning and use empty dict
|
||||
logging.warning(f"Unexpected invalid JSON in trajectory conversion: {tool_call['function']['arguments'][:100]}")
|
||||
arguments = {}
|
||||
|
||||
tool_call_json = {
|
||||
"name": tool_call["function"]["name"],
|
||||
"arguments": json.loads(tool_call["function"]["arguments"]) if isinstance(tool_call["function"]["arguments"], str) else tool_call["function"]["arguments"]
|
||||
"arguments": arguments
|
||||
}
|
||||
content += f"<tool_call>\n{json.dumps(tool_call_json, ensure_ascii=False)}\n</tool_call>\n"
|
||||
|
||||
|
|
@ -417,6 +481,12 @@ class AIAgent:
|
|||
# Generate unique task_id if not provided to isolate VMs between concurrent tasks
|
||||
import uuid
|
||||
effective_task_id = task_id or str(uuid.uuid4())
|
||||
|
||||
# Reset retry counters at the start of each conversation to prevent state leakage
|
||||
self._invalid_tool_retries = 0
|
||||
self._invalid_json_retries = 0
|
||||
self._empty_content_retries = 0
|
||||
|
||||
# Initialize conversation
|
||||
messages = conversation_history or []
|
||||
|
||||
|
|
@ -540,6 +610,45 @@ class AIAgent:
|
|||
time.sleep(wait_time)
|
||||
continue # Retry the API call
|
||||
|
||||
# Check finish_reason before proceeding
|
||||
finish_reason = response.choices[0].finish_reason
|
||||
|
||||
# Handle "length" finish_reason - response was truncated
|
||||
if finish_reason == "length":
|
||||
print(f"{self.log_prefix}⚠️ Response truncated (finish_reason='length') - model hit max output tokens")
|
||||
|
||||
# If we have prior messages, roll back to last complete state
|
||||
if len(messages) > 1:
|
||||
print(f"{self.log_prefix} ⏪ Rolling back to last complete assistant turn")
|
||||
rolled_back_messages = self._get_messages_up_to_last_assistant(messages)
|
||||
|
||||
# Clean up VM
|
||||
try:
|
||||
cleanup_vm(effective_task_id)
|
||||
except Exception as e:
|
||||
if self.verbose_logging:
|
||||
logging.warning(f"Failed to cleanup VM for task {effective_task_id}: {e}")
|
||||
|
||||
return {
|
||||
"final_response": None,
|
||||
"messages": rolled_back_messages,
|
||||
"api_calls": api_call_count,
|
||||
"completed": False,
|
||||
"partial": True,
|
||||
"error": "Response truncated due to output length limit"
|
||||
}
|
||||
else:
|
||||
# First message was truncated - mark as failed
|
||||
print(f"{self.log_prefix}❌ First response truncated - cannot recover")
|
||||
return {
|
||||
"final_response": None,
|
||||
"messages": messages,
|
||||
"api_calls": api_call_count,
|
||||
"completed": False,
|
||||
"failed": True,
|
||||
"error": "First response truncated due to output length limit"
|
||||
}
|
||||
|
||||
break # Success, exit retry loop
|
||||
|
||||
except Exception as api_error:
|
||||
|
|
@ -638,6 +747,40 @@ class AIAgent:
|
|||
if hasattr(self, '_invalid_tool_retries'):
|
||||
self._invalid_tool_retries = 0
|
||||
|
||||
# Validate tool call arguments are valid JSON
|
||||
invalid_json_args = []
|
||||
for tc in assistant_message.tool_calls:
|
||||
try:
|
||||
json.loads(tc.function.arguments)
|
||||
except json.JSONDecodeError as e:
|
||||
invalid_json_args.append((tc.function.name, str(e)))
|
||||
|
||||
if invalid_json_args:
|
||||
# Track retries for invalid JSON arguments
|
||||
self._invalid_json_retries += 1
|
||||
|
||||
tool_name, error_msg = invalid_json_args[0]
|
||||
print(f"{self.log_prefix}⚠️ Invalid JSON in tool call arguments for '{tool_name}': {error_msg}")
|
||||
|
||||
if self._invalid_json_retries < 3:
|
||||
print(f"{self.log_prefix}🔄 Retrying API call ({self._invalid_json_retries}/3)...")
|
||||
# Don't add anything to messages, just retry the API call
|
||||
continue
|
||||
else:
|
||||
print(f"{self.log_prefix}❌ Max retries (3) for invalid JSON arguments exceeded. Stopping as partial.")
|
||||
self._invalid_json_retries = 0 # Reset for next conversation
|
||||
return {
|
||||
"final_response": None,
|
||||
"messages": messages, # Messages up to last valid point
|
||||
"api_calls": api_call_count,
|
||||
"completed": False,
|
||||
"partial": True,
|
||||
"error": f"Model generated invalid JSON arguments for tool '{tool_name}': {error_msg}"
|
||||
}
|
||||
|
||||
# Reset retry counter on successful JSON validation
|
||||
self._invalid_json_retries = 0
|
||||
|
||||
# Extract reasoning from response if available (for reasoning models like minimax, kimi, etc.)
|
||||
reasoning_content = None
|
||||
if hasattr(assistant_message, 'reasoning') and assistant_message.reasoning:
|
||||
|
|
@ -667,10 +810,12 @@ class AIAgent:
|
|||
for i, tool_call in enumerate(assistant_message.tool_calls, 1):
|
||||
function_name = tool_call.function.name
|
||||
|
||||
# Parse arguments - should always succeed since we validated above
|
||||
try:
|
||||
function_args = json.loads(tool_call.function.arguments)
|
||||
except json.JSONDecodeError as e:
|
||||
print(f"❌ Invalid JSON in tool call arguments: {e}")
|
||||
# This shouldn't happen since we validate and retry above
|
||||
logging.warning(f"Unexpected JSON error after validation: {e}")
|
||||
function_args = {}
|
||||
|
||||
# Preview tool call arguments
|
||||
|
|
@ -712,6 +857,48 @@ class AIAgent:
|
|||
# No tool calls - this is the final response
|
||||
final_response = assistant_message.content or ""
|
||||
|
||||
# Check if response only has think block with no actual content after it
|
||||
if not self._has_content_after_think_block(final_response):
|
||||
# Track retries for empty-after-think responses
|
||||
if not hasattr(self, '_empty_content_retries'):
|
||||
self._empty_content_retries = 0
|
||||
self._empty_content_retries += 1
|
||||
|
||||
content_preview = final_response[:80] + "..." if len(final_response) > 80 else final_response
|
||||
print(f"{self.log_prefix}⚠️ Response only contains think block with no content after it")
|
||||
print(f"{self.log_prefix} Content: '{content_preview}'")
|
||||
|
||||
if self._empty_content_retries < 3:
|
||||
print(f"{self.log_prefix}🔄 Retrying API call ({self._empty_content_retries}/3)...")
|
||||
# Don't add the incomplete message, just retry
|
||||
continue
|
||||
else:
|
||||
# Max retries exceeded - roll back to last complete assistant turn
|
||||
print(f"{self.log_prefix}❌ Max retries (3) for empty content exceeded. Rolling back to last complete turn.")
|
||||
self._empty_content_retries = 0 # Reset for next conversation
|
||||
|
||||
rolled_back_messages = self._get_messages_up_to_last_assistant(messages)
|
||||
|
||||
# Clean up VM
|
||||
try:
|
||||
cleanup_vm(effective_task_id)
|
||||
except Exception as e:
|
||||
if self.verbose_logging:
|
||||
logging.warning(f"Failed to cleanup VM for task {effective_task_id}: {e}")
|
||||
|
||||
return {
|
||||
"final_response": None,
|
||||
"messages": rolled_back_messages,
|
||||
"api_calls": api_call_count,
|
||||
"completed": False,
|
||||
"partial": True,
|
||||
"error": "Model generated only think blocks with no actual response after 3 retries"
|
||||
}
|
||||
|
||||
# Reset retry counter on successful content
|
||||
if hasattr(self, '_empty_content_retries'):
|
||||
self._empty_content_retries = 0
|
||||
|
||||
# Extract reasoning from response if available
|
||||
reasoning_content = None
|
||||
if hasattr(assistant_message, 'reasoning') and assistant_message.reasoning:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue