Enhance BatchRunner and AIAgent with new configuration options, default model now opus 4.6, default summarizer gemini flash 3
- Added `max_tokens`, `reasoning_config`, and `prefill_messages` parameters to `BatchRunner` and `AIAgent` for improved model response control. - Updated CLI to support new options for reasoning effort and prefill messages from a JSON file. - Modified example configuration files to reflect changes in default model and summary model. - Improved error handling for loading prefill messages and reasoning configurations in the CLI. - Updated documentation to include new parameters and usage examples.
This commit is contained in:
parent
fa76a331b0
commit
f12ea1bc02
7 changed files with 324 additions and 40 deletions
255
run_agent.py
255
run_agent.py
|
|
@ -66,6 +66,7 @@ _MODEL_CACHE_TTL = 3600 # 1 hour cache TTL
|
|||
DEFAULT_CONTEXT_LENGTHS = {
|
||||
"anthropic/claude-opus-4": 200000,
|
||||
"anthropic/claude-opus-4.5": 200000,
|
||||
"anthropic/claude-opus-4.6": 200000,
|
||||
"anthropic/claude-sonnet-4": 200000,
|
||||
"anthropic/claude-sonnet-4-20250514": 200000,
|
||||
"anthropic/claude-haiku-4.5": 200000,
|
||||
|
|
@ -206,7 +207,7 @@ class ContextCompressor:
|
|||
self,
|
||||
model: str,
|
||||
threshold_percent: float = 0.85,
|
||||
summary_model: str = "google/gemini-2.0-flash-001",
|
||||
summary_model: str = "google/gemini-3-flash-preview",
|
||||
protect_first_n: int = 3,
|
||||
protect_last_n: int = 4,
|
||||
summary_target_tokens: int = 500,
|
||||
|
|
@ -584,7 +585,7 @@ class AIAgent:
|
|||
self,
|
||||
base_url: str = None,
|
||||
api_key: str = None,
|
||||
model: str = "anthropic/claude-sonnet-4-20250514", # OpenRouter format
|
||||
model: str = "anthropic/claude-opus-4.6", # OpenRouter format
|
||||
max_iterations: int = 60, # Default tool-calling iterations
|
||||
tool_delay: float = 1.0,
|
||||
enabled_toolsets: List[str] = None,
|
||||
|
|
@ -601,6 +602,9 @@ class AIAgent:
|
|||
provider_sort: str = None,
|
||||
session_id: str = None,
|
||||
tool_progress_callback: callable = None,
|
||||
max_tokens: int = None,
|
||||
reasoning_config: Dict[str, Any] = None,
|
||||
prefill_messages: List[Dict[str, Any]] = None,
|
||||
):
|
||||
"""
|
||||
Initialize the AI Agent.
|
||||
|
|
@ -625,6 +629,12 @@ class AIAgent:
|
|||
provider_sort (str): Sort providers by price/throughput/latency (optional)
|
||||
session_id (str): Pre-generated session ID for logging (optional, auto-generated if not provided)
|
||||
tool_progress_callback (callable): Callback function(tool_name, args_preview) for progress notifications
|
||||
max_tokens (int): Maximum tokens for model responses (optional, uses model default if not set)
|
||||
reasoning_config (Dict): OpenRouter reasoning configuration override (e.g. {"effort": "none"} to disable thinking).
|
||||
If None, defaults to {"enabled": True, "effort": "xhigh"} for OpenRouter. Set to disable/customize reasoning.
|
||||
prefill_messages (List[Dict]): Messages to prepend to conversation history as prefilled context.
|
||||
Useful for injecting a few-shot example or priming the model's response style.
|
||||
Example: [{"role": "user", "content": "Hi!"}, {"role": "assistant", "content": "Hello!"}]
|
||||
"""
|
||||
self.model = model
|
||||
self.max_iterations = max_iterations
|
||||
|
|
@ -653,6 +663,11 @@ class AIAgent:
|
|||
self.enabled_toolsets = enabled_toolsets
|
||||
self.disabled_toolsets = disabled_toolsets
|
||||
|
||||
# Model response configuration
|
||||
self.max_tokens = max_tokens # None = use model default
|
||||
self.reasoning_config = reasoning_config # None = use default (xhigh for OpenRouter)
|
||||
self.prefill_messages = prefill_messages or [] # Prefilled conversation turns
|
||||
|
||||
# Configure logging
|
||||
if self.verbose_logging:
|
||||
logging.basicConfig(
|
||||
|
|
@ -781,7 +796,7 @@ class AIAgent:
|
|||
# Compresses conversation when approaching model's context limit
|
||||
# Configuration via environment variables (can be set in .env or cli-config.yaml)
|
||||
compression_threshold = float(os.getenv("CONTEXT_COMPRESSION_THRESHOLD", "0.85"))
|
||||
compression_model = os.getenv("CONTEXT_COMPRESSION_MODEL", "google/gemini-2.0-flash-001")
|
||||
compression_model = os.getenv("CONTEXT_COMPRESSION_MODEL", "google/gemini-3-flash-preview")
|
||||
compression_enabled = os.getenv("CONTEXT_COMPRESSION_ENABLED", "true").lower() in ("true", "1", "yes")
|
||||
|
||||
self.context_compressor = ContextCompressor(
|
||||
|
|
@ -1086,6 +1101,25 @@ class AIAgent:
|
|||
|
||||
return json.dumps(formatted_tools, ensure_ascii=False)
|
||||
|
||||
@staticmethod
|
||||
def _convert_scratchpad_to_think(content: str) -> str:
|
||||
"""
|
||||
Convert <REASONING_SCRATCHPAD> tags to <think> tags in content.
|
||||
|
||||
When native thinking/reasoning is disabled and the model is prompted to
|
||||
reason inside <REASONING_SCRATCHPAD> XML tags instead, this converts those
|
||||
to the standard <think> format used in our trajectory storage.
|
||||
|
||||
Args:
|
||||
content: Assistant message content that may contain scratchpad tags
|
||||
|
||||
Returns:
|
||||
Content with scratchpad tags replaced by think tags
|
||||
"""
|
||||
if not content or "<REASONING_SCRATCHPAD>" not in content:
|
||||
return content
|
||||
return content.replace("<REASONING_SCRATCHPAD>", "<think>").replace("</REASONING_SCRATCHPAD>", "</think>")
|
||||
|
||||
def _convert_to_trajectory_format(self, messages: List[Dict[str, Any]], user_query: str, completed: bool) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Convert internal message format to trajectory format for saving.
|
||||
|
|
@ -1120,14 +1154,19 @@ class AIAgent:
|
|||
"value": system_msg
|
||||
})
|
||||
|
||||
# Add the initial user message
|
||||
# Add the actual user prompt (from the dataset) as the first human message
|
||||
trajectory.append({
|
||||
"from": "human",
|
||||
"value": user_query
|
||||
})
|
||||
|
||||
# Process remaining messages
|
||||
i = 1 # Skip the first user message as we already added it
|
||||
# Calculate where agent responses start in the messages list.
|
||||
# Prefill messages are ephemeral (only used to prime model response style)
|
||||
# so we skip them entirely in the saved trajectory.
|
||||
# Layout: [*prefill_msgs, actual_user_msg, ...agent_responses...]
|
||||
num_prefill = len(self.prefill_messages) if self.prefill_messages else 0
|
||||
i = num_prefill + 1 # Skip prefill messages + the actual user message (already added above)
|
||||
|
||||
while i < len(messages):
|
||||
msg = messages[i]
|
||||
|
||||
|
|
@ -1138,12 +1177,14 @@ class AIAgent:
|
|||
# Add <think> tags around reasoning for trajectory storage
|
||||
content = ""
|
||||
|
||||
# Prepend reasoning in <think> tags if available
|
||||
# Prepend reasoning in <think> tags if available (native thinking tokens)
|
||||
if msg.get("reasoning") and msg["reasoning"].strip():
|
||||
content = f"<think>\n{msg['reasoning']}\n</think>\n"
|
||||
|
||||
if msg.get("content") and msg["content"].strip():
|
||||
content += msg["content"] + "\n"
|
||||
# Convert any <REASONING_SCRATCHPAD> tags to <think> tags
|
||||
# (used when native thinking is disabled and model reasons via XML)
|
||||
content += self._convert_scratchpad_to_think(msg["content"]) + "\n"
|
||||
|
||||
# Add tool calls wrapped in XML tags
|
||||
for tool_call in msg["tool_calls"]:
|
||||
|
|
@ -1206,11 +1247,14 @@ class AIAgent:
|
|||
# Add <think> tags around reasoning for trajectory storage
|
||||
content = ""
|
||||
|
||||
# Prepend reasoning in <think> tags if available
|
||||
# Prepend reasoning in <think> tags if available (native thinking tokens)
|
||||
if msg.get("reasoning") and msg["reasoning"].strip():
|
||||
content = f"<think>\n{msg['reasoning']}\n</think>\n"
|
||||
|
||||
content += msg["content"] or ""
|
||||
# Convert any <REASONING_SCRATCHPAD> tags to <think> tags
|
||||
# (used when native thinking is disabled and model reasons via XML)
|
||||
raw_content = msg["content"] or ""
|
||||
content += self._convert_scratchpad_to_think(raw_content)
|
||||
|
||||
trajectory.append({
|
||||
"from": "gpt",
|
||||
|
|
@ -1261,6 +1305,66 @@ class AIAgent:
|
|||
except Exception as e:
|
||||
print(f"⚠️ Failed to save trajectory: {e}")
|
||||
|
||||
def _log_api_payload(self, turn_number: int, api_kwargs: Dict[str, Any], response=None):
|
||||
"""
|
||||
[TEMPORARY DEBUG] Log the full API payload and response token metrics
|
||||
for each agent turn to a per-session JSONL file for inspection.
|
||||
|
||||
Writes one JSON line per turn to logs/payload_<session_id>.jsonl.
|
||||
Tool schemas are summarized (just names) to keep logs readable.
|
||||
|
||||
Args:
|
||||
turn_number: Which API call this is (1-indexed)
|
||||
api_kwargs: The full kwargs dict being passed to chat.completions.create
|
||||
response: The API response object (optional, added after the call completes)
|
||||
"""
|
||||
try:
|
||||
payload_log_file = self.logs_dir / f"payload_{self.session_id}.jsonl"
|
||||
|
||||
# Build a serializable copy of the request payload
|
||||
payload = {
|
||||
"turn": turn_number,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"model": api_kwargs.get("model"),
|
||||
"max_tokens": api_kwargs.get("max_tokens"),
|
||||
"extra_body": api_kwargs.get("extra_body"),
|
||||
"num_tools": len(api_kwargs.get("tools") or []),
|
||||
"tool_names": [t["function"]["name"] for t in (api_kwargs.get("tools") or [])],
|
||||
"messages": api_kwargs.get("messages", []),
|
||||
}
|
||||
|
||||
# Add response token metrics if available
|
||||
if response is not None:
|
||||
try:
|
||||
usage_raw = response.usage.model_dump() if hasattr(response.usage, 'model_dump') else {}
|
||||
payload["response"] = {
|
||||
# Core token counts
|
||||
"prompt_tokens": usage_raw.get("prompt_tokens"),
|
||||
"completion_tokens": usage_raw.get("completion_tokens"),
|
||||
"total_tokens": usage_raw.get("total_tokens"),
|
||||
# Completion breakdown (reasoning tokens, etc.)
|
||||
"completion_tokens_details": usage_raw.get("completion_tokens_details"),
|
||||
# Prompt breakdown (cached tokens, etc.)
|
||||
"prompt_tokens_details": usage_raw.get("prompt_tokens_details"),
|
||||
# Cost tracking
|
||||
"cost": usage_raw.get("cost"),
|
||||
"is_byok": usage_raw.get("is_byok"),
|
||||
"cost_details": usage_raw.get("cost_details"),
|
||||
# Provider info (top-level field from OpenRouter)
|
||||
"provider": getattr(response, 'provider', None),
|
||||
"response_model": getattr(response, 'model', None),
|
||||
}
|
||||
except Exception:
|
||||
payload["response"] = {"error": "failed to extract usage"}
|
||||
|
||||
with open(payload_log_file, "a", encoding="utf-8") as f:
|
||||
f.write(json.dumps(payload, ensure_ascii=False, default=str) + "\n")
|
||||
|
||||
except Exception as e:
|
||||
# Silent fail - don't interrupt the agent for debug logging
|
||||
if self.verbose_logging:
|
||||
logging.warning(f"Failed to log API payload: {e}")
|
||||
|
||||
def _save_session_log(self, messages: List[Dict[str, Any]] = None):
|
||||
"""
|
||||
Save the current session trajectory to the logs directory.
|
||||
|
|
@ -1276,10 +1380,12 @@ class AIAgent:
|
|||
return
|
||||
|
||||
try:
|
||||
# Extract the first user message for the trajectory format
|
||||
# The first message should be the user's initial query
|
||||
# Extract the actual user query for the trajectory format.
|
||||
# Skip prefill messages (they're ephemeral and shouldn't appear in trajectories)
|
||||
# so the first user message we find is the real task prompt.
|
||||
first_user_query = ""
|
||||
for msg in messages:
|
||||
start_idx = len(self.prefill_messages) if self.prefill_messages else 0
|
||||
for msg in messages[start_idx:]:
|
||||
if msg.get("role") == "user":
|
||||
first_user_query = msg.get("content", "")
|
||||
break
|
||||
|
|
@ -1373,6 +1479,12 @@ class AIAgent:
|
|||
# Initialize conversation
|
||||
messages = conversation_history or []
|
||||
|
||||
# Inject prefill messages at the start of conversation (before user's actual prompt)
|
||||
# This is used for few-shot priming, e.g., a greeting exchange to set response style
|
||||
if self.prefill_messages and not conversation_history:
|
||||
for prefill_msg in self.prefill_messages:
|
||||
messages.append(prefill_msg.copy())
|
||||
|
||||
# Add user message
|
||||
messages.append({
|
||||
"role": "user",
|
||||
|
|
@ -1493,6 +1605,10 @@ class AIAgent:
|
|||
"timeout": 600.0 # 10 minute timeout for very long responses
|
||||
}
|
||||
|
||||
# Add max_tokens if configured (overrides model default)
|
||||
if self.max_tokens is not None:
|
||||
api_kwargs["max_tokens"] = self.max_tokens
|
||||
|
||||
# Add extra_body for OpenRouter (provider preferences + reasoning)
|
||||
extra_body = {}
|
||||
|
||||
|
|
@ -1500,12 +1616,17 @@ class AIAgent:
|
|||
if provider_preferences:
|
||||
extra_body["provider"] = provider_preferences
|
||||
|
||||
# Enable reasoning with xhigh effort for OpenRouter
|
||||
# Configure reasoning for OpenRouter
|
||||
# If reasoning_config is explicitly provided, use it (allows disabling/customizing)
|
||||
# Otherwise, default to xhigh effort for OpenRouter models
|
||||
if "openrouter" in self.base_url.lower():
|
||||
extra_body["reasoning"] = {
|
||||
"enabled": True,
|
||||
"effort": "xhigh"
|
||||
}
|
||||
if self.reasoning_config is not None:
|
||||
extra_body["reasoning"] = self.reasoning_config
|
||||
else:
|
||||
extra_body["reasoning"] = {
|
||||
"enabled": True,
|
||||
"effort": "xhigh"
|
||||
}
|
||||
|
||||
if extra_body:
|
||||
api_kwargs["extra_body"] = extra_body
|
||||
|
|
@ -1527,6 +1648,9 @@ class AIAgent:
|
|||
# Log response with provider info if available
|
||||
resp_model = getattr(response, 'model', 'N/A') if response else 'N/A'
|
||||
logging.debug(f"API Response received - Model: {resp_model}, Usage: {response.usage if hasattr(response, 'usage') else 'N/A'}")
|
||||
|
||||
# [DEBUG] Log the full API payload + response token metrics
|
||||
self._log_api_payload(api_call_count, api_kwargs, response=response)
|
||||
|
||||
# Validate response has valid choices before proceeding
|
||||
if response is None or not hasattr(response, 'choices') or response.choices is None or len(response.choices) == 0:
|
||||
|
|
@ -1589,7 +1713,20 @@ class AIAgent:
|
|||
wait_time = min(5 * (2 ** (retry_count - 1)), 120) # 5s, 10s, 20s, 40s, 80s, 120s
|
||||
print(f"{self.log_prefix}⏳ Retrying in {wait_time}s (extended backoff for possible rate limit)...")
|
||||
logging.warning(f"Invalid API response (retry {retry_count}/{max_retries}): {', '.join(error_details)} | Provider: {provider_name}")
|
||||
time.sleep(wait_time)
|
||||
|
||||
# Sleep in small increments to stay responsive to interrupts
|
||||
sleep_end = time.time() + wait_time
|
||||
while time.time() < sleep_end:
|
||||
if self._interrupt_requested:
|
||||
print(f"{self.log_prefix}⚡ Interrupt detected during retry wait, aborting.")
|
||||
return {
|
||||
"final_response": "Operation interrupted.",
|
||||
"messages": messages,
|
||||
"api_calls": api_call_count,
|
||||
"completed": False,
|
||||
"interrupted": True,
|
||||
}
|
||||
time.sleep(0.2)
|
||||
continue # Retry the API call
|
||||
|
||||
# Check finish_reason before proceeding
|
||||
|
|
@ -1668,6 +1805,41 @@ class AIAgent:
|
|||
print(f"{self.log_prefix} 📝 Error: {str(api_error)[:200]}")
|
||||
print(f"{self.log_prefix} 📊 Request context: {len(api_messages)} messages, ~{approx_tokens:,} tokens, {len(self.tools) if self.tools else 0} tools")
|
||||
|
||||
# Check for interrupt before deciding to retry
|
||||
if self._interrupt_requested:
|
||||
print(f"{self.log_prefix}⚡ Interrupt detected during error handling, aborting retries.")
|
||||
return {
|
||||
"final_response": "Operation interrupted.",
|
||||
"messages": messages,
|
||||
"api_calls": api_call_count,
|
||||
"completed": False,
|
||||
"interrupted": True,
|
||||
}
|
||||
|
||||
# Check for non-retryable client errors (4xx HTTP status codes).
|
||||
# These indicate a problem with the request itself (bad model ID,
|
||||
# invalid API key, forbidden, etc.) and will never succeed on retry.
|
||||
is_client_error = any(phrase in error_msg for phrase in [
|
||||
'error code: 400', 'error code: 401', 'error code: 403',
|
||||
'error code: 404', 'error code: 422',
|
||||
'is not a valid model', 'invalid model', 'model not found',
|
||||
'invalid api key', 'invalid_api_key', 'authentication',
|
||||
'unauthorized', 'forbidden', 'not found',
|
||||
])
|
||||
|
||||
if is_client_error:
|
||||
print(f"{self.log_prefix}❌ Non-retryable client error detected. Aborting immediately.")
|
||||
print(f"{self.log_prefix} 💡 This type of error won't be fixed by retrying.")
|
||||
logging.error(f"{self.log_prefix}Non-retryable client error: {api_error}")
|
||||
return {
|
||||
"final_response": None,
|
||||
"messages": messages,
|
||||
"api_calls": api_call_count,
|
||||
"completed": False,
|
||||
"failed": True,
|
||||
"error": str(api_error),
|
||||
}
|
||||
|
||||
# Check for non-retryable errors (context length exceeded)
|
||||
is_context_length_error = any(phrase in error_msg for phrase in [
|
||||
'context length', 'maximum context', 'token limit',
|
||||
|
|
@ -1708,7 +1880,21 @@ class AIAgent:
|
|||
print(f"⚠️ OpenAI-compatible API call failed (attempt {retry_count}/{max_retries}): {str(api_error)[:100]}")
|
||||
print(f"⏳ Retrying in {wait_time}s...")
|
||||
logging.warning(f"API retry {retry_count}/{max_retries} after error: {api_error}")
|
||||
time.sleep(wait_time)
|
||||
|
||||
# Sleep in small increments so we can respond to interrupts quickly
|
||||
# instead of blocking the entire wait_time in one sleep() call
|
||||
sleep_end = time.time() + wait_time
|
||||
while time.time() < sleep_end:
|
||||
if self._interrupt_requested:
|
||||
print(f"{self.log_prefix}⚡ Interrupt detected during retry wait, aborting.")
|
||||
return {
|
||||
"final_response": "Operation interrupted.",
|
||||
"messages": messages,
|
||||
"api_calls": api_call_count,
|
||||
"completed": False,
|
||||
"interrupted": True,
|
||||
}
|
||||
time.sleep(0.2) # Check interrupt every 200ms
|
||||
|
||||
try:
|
||||
assistant_message = response.choices[0].message
|
||||
|
|
@ -2069,13 +2255,28 @@ class AIAgent:
|
|||
if self.ephemeral_system_prompt:
|
||||
api_messages = [{"role": "system", "content": self.ephemeral_system_prompt}] + api_messages
|
||||
|
||||
summary_response = self.client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=api_messages,
|
||||
# Build extra_body for summary call (same reasoning config as main loop)
|
||||
summary_extra_body = {}
|
||||
if "openrouter" in self.base_url.lower():
|
||||
if self.reasoning_config is not None:
|
||||
summary_extra_body["reasoning"] = self.reasoning_config
|
||||
else:
|
||||
summary_extra_body["reasoning"] = {
|
||||
"enabled": True,
|
||||
"effort": "xhigh"
|
||||
}
|
||||
|
||||
summary_kwargs = {
|
||||
"model": self.model,
|
||||
"messages": api_messages,
|
||||
# No tools parameter - forces text response
|
||||
extra_headers=self.extra_headers,
|
||||
extra_body=self.extra_body,
|
||||
)
|
||||
}
|
||||
if self.max_tokens is not None:
|
||||
summary_kwargs["max_tokens"] = self.max_tokens
|
||||
if summary_extra_body:
|
||||
summary_kwargs["extra_body"] = summary_extra_body
|
||||
|
||||
summary_response = self.client.chat.completions.create(**summary_kwargs)
|
||||
|
||||
if summary_response.choices and summary_response.choices[0].message.content:
|
||||
final_response = summary_response.choices[0].message.content
|
||||
|
|
@ -2151,7 +2352,7 @@ class AIAgent:
|
|||
|
||||
def main(
|
||||
query: str = None,
|
||||
model: str = "anthropic/claude-sonnet-4-20250514",
|
||||
model: str = "anthropic/claude-opus-4.6",
|
||||
api_key: str = None,
|
||||
base_url: str = "https://openrouter.ai/api/v1",
|
||||
max_turns: int = 10,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue