Enhance BatchRunner and AIAgent with new configuration options, default model now opus 4.6, default summarizer gemini flash 3

- Added `max_tokens`, `reasoning_config`, and `prefill_messages` parameters to `BatchRunner` and `AIAgent` for improved model response control.
- Updated CLI to support new options for reasoning effort and prefill messages from a JSON file.
- Modified example configuration files to reflect changes in default model and summary model.
- Improved error handling for loading prefill messages and reasoning configurations in the CLI.
- Updated documentation to include new parameters and usage examples.
This commit is contained in:
teknium 2026-02-08 10:49:24 +00:00
parent fa76a331b0
commit f12ea1bc02
7 changed files with 324 additions and 40 deletions

View file

@ -66,6 +66,7 @@ _MODEL_CACHE_TTL = 3600 # 1 hour cache TTL
DEFAULT_CONTEXT_LENGTHS = {
"anthropic/claude-opus-4": 200000,
"anthropic/claude-opus-4.5": 200000,
"anthropic/claude-opus-4.6": 200000,
"anthropic/claude-sonnet-4": 200000,
"anthropic/claude-sonnet-4-20250514": 200000,
"anthropic/claude-haiku-4.5": 200000,
@ -206,7 +207,7 @@ class ContextCompressor:
self,
model: str,
threshold_percent: float = 0.85,
summary_model: str = "google/gemini-2.0-flash-001",
summary_model: str = "google/gemini-3-flash-preview",
protect_first_n: int = 3,
protect_last_n: int = 4,
summary_target_tokens: int = 500,
@ -584,7 +585,7 @@ class AIAgent:
self,
base_url: str = None,
api_key: str = None,
model: str = "anthropic/claude-sonnet-4-20250514", # OpenRouter format
model: str = "anthropic/claude-opus-4.6", # OpenRouter format
max_iterations: int = 60, # Default tool-calling iterations
tool_delay: float = 1.0,
enabled_toolsets: List[str] = None,
@ -601,6 +602,9 @@ class AIAgent:
provider_sort: str = None,
session_id: str = None,
tool_progress_callback: callable = None,
max_tokens: int = None,
reasoning_config: Dict[str, Any] = None,
prefill_messages: List[Dict[str, Any]] = None,
):
"""
Initialize the AI Agent.
@ -625,6 +629,12 @@ class AIAgent:
provider_sort (str): Sort providers by price/throughput/latency (optional)
session_id (str): Pre-generated session ID for logging (optional, auto-generated if not provided)
tool_progress_callback (callable): Callback function(tool_name, args_preview) for progress notifications
max_tokens (int): Maximum tokens for model responses (optional, uses model default if not set)
reasoning_config (Dict): OpenRouter reasoning configuration override (e.g. {"effort": "none"} to disable thinking).
If None, defaults to {"enabled": True, "effort": "xhigh"} for OpenRouter. Set to disable/customize reasoning.
prefill_messages (List[Dict]): Messages to prepend to conversation history as prefilled context.
Useful for injecting a few-shot example or priming the model's response style.
Example: [{"role": "user", "content": "Hi!"}, {"role": "assistant", "content": "Hello!"}]
"""
self.model = model
self.max_iterations = max_iterations
@ -653,6 +663,11 @@ class AIAgent:
self.enabled_toolsets = enabled_toolsets
self.disabled_toolsets = disabled_toolsets
# Model response configuration
self.max_tokens = max_tokens # None = use model default
self.reasoning_config = reasoning_config # None = use default (xhigh for OpenRouter)
self.prefill_messages = prefill_messages or [] # Prefilled conversation turns
# Configure logging
if self.verbose_logging:
logging.basicConfig(
@ -781,7 +796,7 @@ class AIAgent:
# Compresses conversation when approaching model's context limit
# Configuration via environment variables (can be set in .env or cli-config.yaml)
compression_threshold = float(os.getenv("CONTEXT_COMPRESSION_THRESHOLD", "0.85"))
compression_model = os.getenv("CONTEXT_COMPRESSION_MODEL", "google/gemini-2.0-flash-001")
compression_model = os.getenv("CONTEXT_COMPRESSION_MODEL", "google/gemini-3-flash-preview")
compression_enabled = os.getenv("CONTEXT_COMPRESSION_ENABLED", "true").lower() in ("true", "1", "yes")
self.context_compressor = ContextCompressor(
@ -1086,6 +1101,25 @@ class AIAgent:
return json.dumps(formatted_tools, ensure_ascii=False)
@staticmethod
def _convert_scratchpad_to_think(content: str) -> str:
"""
Convert <REASONING_SCRATCHPAD> tags to <think> tags in content.
When native thinking/reasoning is disabled and the model is prompted to
reason inside <REASONING_SCRATCHPAD> XML tags instead, this converts those
to the standard <think> format used in our trajectory storage.
Args:
content: Assistant message content that may contain scratchpad tags
Returns:
Content with scratchpad tags replaced by think tags
"""
if not content or "<REASONING_SCRATCHPAD>" not in content:
return content
return content.replace("<REASONING_SCRATCHPAD>", "<think>").replace("</REASONING_SCRATCHPAD>", "</think>")
def _convert_to_trajectory_format(self, messages: List[Dict[str, Any]], user_query: str, completed: bool) -> List[Dict[str, Any]]:
"""
Convert internal message format to trajectory format for saving.
@ -1120,14 +1154,19 @@ class AIAgent:
"value": system_msg
})
# Add the initial user message
# Add the actual user prompt (from the dataset) as the first human message
trajectory.append({
"from": "human",
"value": user_query
})
# Process remaining messages
i = 1 # Skip the first user message as we already added it
# Calculate where agent responses start in the messages list.
# Prefill messages are ephemeral (only used to prime model response style)
# so we skip them entirely in the saved trajectory.
# Layout: [*prefill_msgs, actual_user_msg, ...agent_responses...]
num_prefill = len(self.prefill_messages) if self.prefill_messages else 0
i = num_prefill + 1 # Skip prefill messages + the actual user message (already added above)
while i < len(messages):
msg = messages[i]
@ -1138,12 +1177,14 @@ class AIAgent:
# Add <think> tags around reasoning for trajectory storage
content = ""
# Prepend reasoning in <think> tags if available
# Prepend reasoning in <think> tags if available (native thinking tokens)
if msg.get("reasoning") and msg["reasoning"].strip():
content = f"<think>\n{msg['reasoning']}\n</think>\n"
if msg.get("content") and msg["content"].strip():
content += msg["content"] + "\n"
# Convert any <REASONING_SCRATCHPAD> tags to <think> tags
# (used when native thinking is disabled and model reasons via XML)
content += self._convert_scratchpad_to_think(msg["content"]) + "\n"
# Add tool calls wrapped in XML tags
for tool_call in msg["tool_calls"]:
@ -1206,11 +1247,14 @@ class AIAgent:
# Add <think> tags around reasoning for trajectory storage
content = ""
# Prepend reasoning in <think> tags if available
# Prepend reasoning in <think> tags if available (native thinking tokens)
if msg.get("reasoning") and msg["reasoning"].strip():
content = f"<think>\n{msg['reasoning']}\n</think>\n"
content += msg["content"] or ""
# Convert any <REASONING_SCRATCHPAD> tags to <think> tags
# (used when native thinking is disabled and model reasons via XML)
raw_content = msg["content"] or ""
content += self._convert_scratchpad_to_think(raw_content)
trajectory.append({
"from": "gpt",
@ -1261,6 +1305,66 @@ class AIAgent:
except Exception as e:
print(f"⚠️ Failed to save trajectory: {e}")
def _log_api_payload(self, turn_number: int, api_kwargs: Dict[str, Any], response=None):
"""
[TEMPORARY DEBUG] Log the full API payload and response token metrics
for each agent turn to a per-session JSONL file for inspection.
Writes one JSON line per turn to logs/payload_<session_id>.jsonl.
Tool schemas are summarized (just names) to keep logs readable.
Args:
turn_number: Which API call this is (1-indexed)
api_kwargs: The full kwargs dict being passed to chat.completions.create
response: The API response object (optional, added after the call completes)
"""
try:
payload_log_file = self.logs_dir / f"payload_{self.session_id}.jsonl"
# Build a serializable copy of the request payload
payload = {
"turn": turn_number,
"timestamp": datetime.now().isoformat(),
"model": api_kwargs.get("model"),
"max_tokens": api_kwargs.get("max_tokens"),
"extra_body": api_kwargs.get("extra_body"),
"num_tools": len(api_kwargs.get("tools") or []),
"tool_names": [t["function"]["name"] for t in (api_kwargs.get("tools") or [])],
"messages": api_kwargs.get("messages", []),
}
# Add response token metrics if available
if response is not None:
try:
usage_raw = response.usage.model_dump() if hasattr(response.usage, 'model_dump') else {}
payload["response"] = {
# Core token counts
"prompt_tokens": usage_raw.get("prompt_tokens"),
"completion_tokens": usage_raw.get("completion_tokens"),
"total_tokens": usage_raw.get("total_tokens"),
# Completion breakdown (reasoning tokens, etc.)
"completion_tokens_details": usage_raw.get("completion_tokens_details"),
# Prompt breakdown (cached tokens, etc.)
"prompt_tokens_details": usage_raw.get("prompt_tokens_details"),
# Cost tracking
"cost": usage_raw.get("cost"),
"is_byok": usage_raw.get("is_byok"),
"cost_details": usage_raw.get("cost_details"),
# Provider info (top-level field from OpenRouter)
"provider": getattr(response, 'provider', None),
"response_model": getattr(response, 'model', None),
}
except Exception:
payload["response"] = {"error": "failed to extract usage"}
with open(payload_log_file, "a", encoding="utf-8") as f:
f.write(json.dumps(payload, ensure_ascii=False, default=str) + "\n")
except Exception as e:
# Silent fail - don't interrupt the agent for debug logging
if self.verbose_logging:
logging.warning(f"Failed to log API payload: {e}")
def _save_session_log(self, messages: List[Dict[str, Any]] = None):
"""
Save the current session trajectory to the logs directory.
@ -1276,10 +1380,12 @@ class AIAgent:
return
try:
# Extract the first user message for the trajectory format
# The first message should be the user's initial query
# Extract the actual user query for the trajectory format.
# Skip prefill messages (they're ephemeral and shouldn't appear in trajectories)
# so the first user message we find is the real task prompt.
first_user_query = ""
for msg in messages:
start_idx = len(self.prefill_messages) if self.prefill_messages else 0
for msg in messages[start_idx:]:
if msg.get("role") == "user":
first_user_query = msg.get("content", "")
break
@ -1373,6 +1479,12 @@ class AIAgent:
# Initialize conversation
messages = conversation_history or []
# Inject prefill messages at the start of conversation (before user's actual prompt)
# This is used for few-shot priming, e.g., a greeting exchange to set response style
if self.prefill_messages and not conversation_history:
for prefill_msg in self.prefill_messages:
messages.append(prefill_msg.copy())
# Add user message
messages.append({
"role": "user",
@ -1493,6 +1605,10 @@ class AIAgent:
"timeout": 600.0 # 10 minute timeout for very long responses
}
# Add max_tokens if configured (overrides model default)
if self.max_tokens is not None:
api_kwargs["max_tokens"] = self.max_tokens
# Add extra_body for OpenRouter (provider preferences + reasoning)
extra_body = {}
@ -1500,12 +1616,17 @@ class AIAgent:
if provider_preferences:
extra_body["provider"] = provider_preferences
# Enable reasoning with xhigh effort for OpenRouter
# Configure reasoning for OpenRouter
# If reasoning_config is explicitly provided, use it (allows disabling/customizing)
# Otherwise, default to xhigh effort for OpenRouter models
if "openrouter" in self.base_url.lower():
extra_body["reasoning"] = {
"enabled": True,
"effort": "xhigh"
}
if self.reasoning_config is not None:
extra_body["reasoning"] = self.reasoning_config
else:
extra_body["reasoning"] = {
"enabled": True,
"effort": "xhigh"
}
if extra_body:
api_kwargs["extra_body"] = extra_body
@ -1527,6 +1648,9 @@ class AIAgent:
# Log response with provider info if available
resp_model = getattr(response, 'model', 'N/A') if response else 'N/A'
logging.debug(f"API Response received - Model: {resp_model}, Usage: {response.usage if hasattr(response, 'usage') else 'N/A'}")
# [DEBUG] Log the full API payload + response token metrics
self._log_api_payload(api_call_count, api_kwargs, response=response)
# Validate response has valid choices before proceeding
if response is None or not hasattr(response, 'choices') or response.choices is None or len(response.choices) == 0:
@ -1589,7 +1713,20 @@ class AIAgent:
wait_time = min(5 * (2 ** (retry_count - 1)), 120) # 5s, 10s, 20s, 40s, 80s, 120s
print(f"{self.log_prefix}⏳ Retrying in {wait_time}s (extended backoff for possible rate limit)...")
logging.warning(f"Invalid API response (retry {retry_count}/{max_retries}): {', '.join(error_details)} | Provider: {provider_name}")
time.sleep(wait_time)
# Sleep in small increments to stay responsive to interrupts
sleep_end = time.time() + wait_time
while time.time() < sleep_end:
if self._interrupt_requested:
print(f"{self.log_prefix}⚡ Interrupt detected during retry wait, aborting.")
return {
"final_response": "Operation interrupted.",
"messages": messages,
"api_calls": api_call_count,
"completed": False,
"interrupted": True,
}
time.sleep(0.2)
continue # Retry the API call
# Check finish_reason before proceeding
@ -1668,6 +1805,41 @@ class AIAgent:
print(f"{self.log_prefix} 📝 Error: {str(api_error)[:200]}")
print(f"{self.log_prefix} 📊 Request context: {len(api_messages)} messages, ~{approx_tokens:,} tokens, {len(self.tools) if self.tools else 0} tools")
# Check for interrupt before deciding to retry
if self._interrupt_requested:
print(f"{self.log_prefix}⚡ Interrupt detected during error handling, aborting retries.")
return {
"final_response": "Operation interrupted.",
"messages": messages,
"api_calls": api_call_count,
"completed": False,
"interrupted": True,
}
# Check for non-retryable client errors (4xx HTTP status codes).
# These indicate a problem with the request itself (bad model ID,
# invalid API key, forbidden, etc.) and will never succeed on retry.
is_client_error = any(phrase in error_msg for phrase in [
'error code: 400', 'error code: 401', 'error code: 403',
'error code: 404', 'error code: 422',
'is not a valid model', 'invalid model', 'model not found',
'invalid api key', 'invalid_api_key', 'authentication',
'unauthorized', 'forbidden', 'not found',
])
if is_client_error:
print(f"{self.log_prefix}❌ Non-retryable client error detected. Aborting immediately.")
print(f"{self.log_prefix} 💡 This type of error won't be fixed by retrying.")
logging.error(f"{self.log_prefix}Non-retryable client error: {api_error}")
return {
"final_response": None,
"messages": messages,
"api_calls": api_call_count,
"completed": False,
"failed": True,
"error": str(api_error),
}
# Check for non-retryable errors (context length exceeded)
is_context_length_error = any(phrase in error_msg for phrase in [
'context length', 'maximum context', 'token limit',
@ -1708,7 +1880,21 @@ class AIAgent:
print(f"⚠️ OpenAI-compatible API call failed (attempt {retry_count}/{max_retries}): {str(api_error)[:100]}")
print(f"⏳ Retrying in {wait_time}s...")
logging.warning(f"API retry {retry_count}/{max_retries} after error: {api_error}")
time.sleep(wait_time)
# Sleep in small increments so we can respond to interrupts quickly
# instead of blocking the entire wait_time in one sleep() call
sleep_end = time.time() + wait_time
while time.time() < sleep_end:
if self._interrupt_requested:
print(f"{self.log_prefix}⚡ Interrupt detected during retry wait, aborting.")
return {
"final_response": "Operation interrupted.",
"messages": messages,
"api_calls": api_call_count,
"completed": False,
"interrupted": True,
}
time.sleep(0.2) # Check interrupt every 200ms
try:
assistant_message = response.choices[0].message
@ -2069,13 +2255,28 @@ class AIAgent:
if self.ephemeral_system_prompt:
api_messages = [{"role": "system", "content": self.ephemeral_system_prompt}] + api_messages
summary_response = self.client.chat.completions.create(
model=self.model,
messages=api_messages,
# Build extra_body for summary call (same reasoning config as main loop)
summary_extra_body = {}
if "openrouter" in self.base_url.lower():
if self.reasoning_config is not None:
summary_extra_body["reasoning"] = self.reasoning_config
else:
summary_extra_body["reasoning"] = {
"enabled": True,
"effort": "xhigh"
}
summary_kwargs = {
"model": self.model,
"messages": api_messages,
# No tools parameter - forces text response
extra_headers=self.extra_headers,
extra_body=self.extra_body,
)
}
if self.max_tokens is not None:
summary_kwargs["max_tokens"] = self.max_tokens
if summary_extra_body:
summary_kwargs["extra_body"] = summary_extra_body
summary_response = self.client.chat.completions.create(**summary_kwargs)
if summary_response.choices and summary_response.choices[0].message.content:
final_response = summary_response.choices[0].message.content
@ -2151,7 +2352,7 @@ class AIAgent:
def main(
query: str = None,
model: str = "anthropic/claude-sonnet-4-20250514",
model: str = "anthropic/claude-opus-4.6",
api_key: str = None,
base_url: str = "https://openrouter.ai/api/v1",
max_turns: int = 10,