From f018999da97862bfc919a8eaddfec57ce0cdea18 Mon Sep 17 00:00:00 2001 From: teknium1 Date: Tue, 3 Feb 2026 23:41:26 -0800 Subject: [PATCH 1/6] initial RL training tools and loop --- model_tools.py | 360 ++++++++++++++++++++++++++++++- rl_cli.py | 363 +++++++++++++++++++++++++++++++ tools/__init__.py | 31 +++ tools/rl_training_tool.py | 436 ++++++++++++++++++++++++++++++++++++++ toolsets.py | 12 ++ 5 files changed, 1199 insertions(+), 3 deletions(-) create mode 100644 rl_cli.py create mode 100644 tools/rl_training_tool.py diff --git a/model_tools.py b/model_tools.py index e78323f6..ebabaf56 100644 --- a/model_tools.py +++ b/model_tools.py @@ -39,6 +39,21 @@ from tools.vision_tools import vision_analyze_tool, check_vision_requirements from tools.mixture_of_agents_tool import mixture_of_agents_tool, check_moa_requirements from tools.image_generation_tool import image_generate_tool, check_image_generation_requirements from tools.skills_tool import skills_categories, skills_list, skill_view, check_skills_requirements, SKILLS_TOOL_DESCRIPTION +# RL Training tools (Tinker-Atropos) +from tools.rl_training_tool import ( + rl_list_environments, + rl_select_environment, + rl_get_current_config, + rl_edit_config, + rl_start_training, + rl_check_status, + rl_stop_training, + rl_get_results, + rl_test_inference, + rl_list_runs, + rl_health_check, + check_rl_api_keys, +) # Cronjob management tools (CLI-only) from tools.cronjob_tools import ( schedule_cronjob, @@ -128,6 +143,19 @@ TOOLSET_REQUIREMENTS = { "setup_url": None, "tools": ["skills_categories", "skills_list", "skill_view"], }, + "rl": { + "name": "RL Training (Tinker-Atropos)", + "env_vars": ["TINKER_API_KEY", "WANDB_API_KEY"], + "check_fn": check_rl_api_keys, + "setup_url": "https://wandb.ai/authorize", + "tools": [ + "rl_list_environments", "rl_select_environment", + "rl_get_current_config", "rl_edit_config", + "rl_start_training", "rl_check_status", + "rl_stop_training", "rl_get_results", + "rl_test_inference", 
"rl_list_runs", + ], + }, } @@ -471,6 +499,199 @@ def get_cronjob_tool_definitions_formatted() -> List[Dict[str, Any]]: ]] +def get_rl_tool_definitions() -> List[Dict[str, Any]]: + """ + Get tool definitions for RL training tools in OpenAI's expected format. + + These tools enable running RL training through Tinker-Atropos. + + Returns: + List[Dict]: List of RL tool definitions compatible with OpenAI API + """ + return [ + { + "type": "function", + "function": { + "name": "rl_list_environments", + "description": "List all available RL environments. Returns environment names, paths, and descriptions. TIP: Read the file_path with file tools to understand how each environment works (verifiers, data loading, rewards).", + "parameters": { + "type": "object", + "properties": {}, + "required": [] + } + } + }, + { + "type": "function", + "function": { + "name": "rl_select_environment", + "description": "Select an RL environment for training. Loads the environment's default configuration. After selecting, use rl_get_current_config() to see settings and rl_edit_config() to modify them.", + "parameters": { + "type": "object", + "properties": { + "name": { + "type": "string", + "description": "Name of the environment to select (from rl_list_environments)" + } + }, + "required": ["name"] + } + } + }, + { + "type": "function", + "function": { + "name": "rl_get_current_config", + "description": "Get the current environment configuration. Returns only fields that can be modified: group_size, max_token_length, total_steps, steps_per_eval, use_wandb, wandb_name, max_num_workers.", + "parameters": { + "type": "object", + "properties": {}, + "required": [] + } + } + }, + { + "type": "function", + "function": { + "name": "rl_edit_config", + "description": "Update a configuration field. 
Valid fields: group_size (int), max_token_length (int), total_steps (int), steps_per_eval (int), use_wandb (bool), wandb_name (str), max_num_workers (int).", + "parameters": { + "type": "object", + "properties": { + "field": { + "type": "string", + "description": "Name of the field to update" + }, + "value": { + "description": "New value for the field" + } + }, + "required": ["field", "value"] + } + } + }, + { + "type": "function", + "function": { + "name": "rl_start_training", + "description": "Start a new RL training run. WARNING: Training can take hours. Use rl_check_status() to monitor (30-minute intervals recommended). Test with rl_test_inference() first!", + "parameters": { + "type": "object", + "properties": { + "wandb_project": { + "type": "string", + "description": "WandB project name for logging", + "default": "rl-training" + }, + "lora_rank": { + "type": "integer", + "description": "LoRA rank for training", + "default": 32 + }, + "learning_rate": { + "type": "number", + "description": "Learning rate", + "default": 4e-5 + } + }, + "required": [] + } + } + }, + { + "type": "function", + "function": { + "name": "rl_check_status", + "description": "Get status and metrics for a training run. RATE LIMITED: enforces 30-minute minimum between checks for the same run. Returns WandB metrics: step, state, reward_mean, loss, percent_correct.", + "parameters": { + "type": "object", + "properties": { + "run_id": { + "type": "string", + "description": "The run ID from rl_start_training()" + } + }, + "required": ["run_id"] + } + } + }, + { + "type": "function", + "function": { + "name": "rl_stop_training", + "description": "Stop a running training job. 
Use if metrics look bad, training is stagnant, or you want to try different settings.", + "parameters": { + "type": "object", + "properties": { + "run_id": { + "type": "string", + "description": "The run ID to stop" + } + }, + "required": ["run_id"] + } + } + }, + { + "type": "function", + "function": { + "name": "rl_get_results", + "description": "Get final results and metrics for a completed training run. Returns final metrics and path to trained weights.", + "parameters": { + "type": "object", + "properties": { + "run_id": { + "type": "string", + "description": "The run ID to get results for" + } + }, + "required": ["run_id"] + } + } + }, + { + "type": "function", + "function": { + "name": "rl_test_inference", + "description": "Test inference + verifier on sample prompts WITHOUT full training. Use to validate environments before committing to long training runs. Tests data loading, inference, and verifier logic.", + "parameters": { + "type": "object", + "properties": { + "prompts": { + "type": "array", + "items": {"type": "string"}, + "description": "List of test prompts to run through the environment" + }, + "max_tokens": { + "type": "integer", + "description": "Maximum tokens to generate per prompt", + "default": 256 + }, + "temperature": { + "type": "number", + "description": "Sampling temperature", + "default": 1.0 + } + }, + "required": ["prompts"] + } + } + }, + { + "type": "function", + "function": { + "name": "rl_list_runs", + "description": "List all training runs (active and completed) with their status.", + "parameters": { + "type": "object", + "properties": {}, + "required": [] + } + } + } + ] + + def get_all_tool_names() -> List[str]: """ Get the names of all available tools across all toolsets. 
@@ -519,6 +740,16 @@ def get_all_tool_names() -> List[str]: "schedule_cronjob", "list_cronjobs", "remove_cronjob" ]) + # RL Training tools + if check_rl_api_keys(): + tool_names.extend([ + "rl_list_environments", "rl_select_environment", + "rl_get_current_config", "rl_edit_config", + "rl_start_training", "rl_check_status", + "rl_stop_training", "rl_get_results", + "rl_test_inference", "rl_list_runs" + ]) + return tool_names @@ -557,7 +788,18 @@ def get_toolset_for_tool(tool_name: str) -> str: # Cronjob management tools "schedule_cronjob": "cronjob_tools", "list_cronjobs": "cronjob_tools", - "remove_cronjob": "cronjob_tools" + "remove_cronjob": "cronjob_tools", + # RL Training tools + "rl_list_environments": "rl_tools", + "rl_select_environment": "rl_tools", + "rl_get_current_config": "rl_tools", + "rl_edit_config": "rl_tools", + "rl_start_training": "rl_tools", + "rl_check_status": "rl_tools", + "rl_stop_training": "rl_tools", + "rl_get_results": "rl_tools", + "rl_test_inference": "rl_tools", + "rl_list_runs": "rl_tools", } return toolset_mapping.get(tool_name, "unknown") @@ -635,6 +877,11 @@ def get_tool_definitions( for tool in get_cronjob_tool_definitions_formatted(): all_available_tools_map[tool["function"]["name"]] = tool + # RL Training tools + if check_rl_api_keys(): + for tool in get_rl_tool_definitions(): + all_available_tools_map[tool["function"]["name"]] = tool + # Determine which tools to include based on toolsets tools_to_include = set() @@ -663,7 +910,14 @@ def get_tool_definitions( "browser_press", "browser_close", "browser_get_images", "browser_vision" ], - "cronjob_tools": ["schedule_cronjob", "list_cronjobs", "remove_cronjob"] + "cronjob_tools": ["schedule_cronjob", "list_cronjobs", "remove_cronjob"], + "rl_tools": [ + "rl_list_environments", "rl_select_environment", + "rl_get_current_config", "rl_edit_config", + "rl_start_training", "rl_check_status", + "rl_stop_training", "rl_get_results", + "rl_test_inference", "rl_list_runs" + ] } legacy_tools 
= legacy_map.get(toolset_name, []) tools_to_include.update(legacy_tools) @@ -708,7 +962,14 @@ def get_tool_definitions( "browser_press", "browser_close", "browser_get_images", "browser_vision" ], - "cronjob_tools": ["schedule_cronjob", "list_cronjobs", "remove_cronjob"] + "cronjob_tools": ["schedule_cronjob", "list_cronjobs", "remove_cronjob"], + "rl_tools": [ + "rl_list_environments", "rl_select_environment", + "rl_get_current_config", "rl_edit_config", + "rl_start_training", "rl_check_status", + "rl_stop_training", "rl_get_results", + "rl_test_inference", "rl_list_runs" + ] } legacy_tools = legacy_map.get(toolset_name, []) tools_to_include.difference_update(legacy_tools) @@ -1018,6 +1279,89 @@ def handle_cronjob_function_call( return json.dumps({"error": f"Unknown cronjob function: {function_name}"}, ensure_ascii=False) +def handle_rl_function_call( + function_name: str, + function_args: Dict[str, Any] +) -> str: + """ + Handle function calls for RL training tools. + + These tools communicate with the RL API server to manage training runs. 
+ + Args: + function_name (str): Name of the RL function to call + function_args (Dict): Arguments for the function + + Returns: + str: Function result as JSON string + """ + # Run async functions in event loop + import asyncio + + try: + loop = asyncio.get_event_loop() + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + if function_name == "rl_list_environments": + return loop.run_until_complete(rl_list_environments()) + + elif function_name == "rl_select_environment": + return loop.run_until_complete( + rl_select_environment(name=function_args.get("name", "")) + ) + + elif function_name == "rl_get_current_config": + return loop.run_until_complete(rl_get_current_config()) + + elif function_name == "rl_edit_config": + return loop.run_until_complete( + rl_edit_config( + field=function_args.get("field", ""), + value=function_args.get("value") + ) + ) + + elif function_name == "rl_start_training": + return loop.run_until_complete( + rl_start_training( + wandb_project=function_args.get("wandb_project", "rl-training"), + lora_rank=function_args.get("lora_rank", 32), + learning_rate=function_args.get("learning_rate", 4e-5) + ) + ) + + elif function_name == "rl_check_status": + return loop.run_until_complete( + rl_check_status(run_id=function_args.get("run_id", "")) + ) + + elif function_name == "rl_stop_training": + return loop.run_until_complete( + rl_stop_training(run_id=function_args.get("run_id", "")) + ) + + elif function_name == "rl_get_results": + return loop.run_until_complete( + rl_get_results(run_id=function_args.get("run_id", "")) + ) + + elif function_name == "rl_test_inference": + return loop.run_until_complete( + rl_test_inference( + prompts=function_args.get("prompts", []), + max_tokens=function_args.get("max_tokens", 256), + temperature=function_args.get("temperature", 1.0) + ) + ) + + elif function_name == "rl_list_runs": + return loop.run_until_complete(rl_list_runs()) + + return json.dumps({"error": f"Unknown RL 
function: {function_name}"}, ensure_ascii=False) + + def handle_function_call( function_name: str, function_args: Dict[str, Any], @@ -1081,6 +1425,16 @@ def handle_function_call( elif function_name in ["schedule_cronjob", "list_cronjobs", "remove_cronjob"]: return handle_cronjob_function_call(function_name, function_args, task_id) + # Route RL training tools + elif function_name in [ + "rl_list_environments", "rl_select_environment", + "rl_get_current_config", "rl_edit_config", + "rl_start_training", "rl_check_status", + "rl_stop_training", "rl_get_results", + "rl_test_inference", "rl_list_runs" + ]: + return handle_rl_function_call(function_name, function_args) + else: error_msg = f"Unknown function: {function_name}" print(f"āŒ {error_msg}") diff --git a/rl_cli.py b/rl_cli.py new file mode 100644 index 00000000..cd76c91d --- /dev/null +++ b/rl_cli.py @@ -0,0 +1,363 @@ +#!/usr/bin/env python3 +""" +RL Training CLI Runner + +Dedicated CLI runner for RL training workflows with: +- Extended timeouts for long-running training +- RL-focused system prompts +- Full toolset including RL training tools +- Special handling for 30-minute check intervals + +Usage: + python rl_cli.py "Train a model on GSM8k for math reasoning" + python rl_cli.py --interactive + python rl_cli.py --list-environments + +Environment Variables: + TINKER_API_KEY: API key for Tinker service (required) + WANDB_API_KEY: API key for WandB metrics (required) + RL_API_URL: URL of RL API server (default: http://localhost:8080) + OPENROUTER_API_KEY: API key for OpenRouter (required for agent) +""" + +import asyncio +import os +import sys +from pathlib import Path + +import fire + +# Load environment variables from .env file +from dotenv import load_dotenv + +env_path = Path(__file__).parent / '.env' +if env_path.exists(): + load_dotenv(dotenv_path=env_path) + print(f"āœ… Loaded environment variables from {env_path}") + +# Import agent and tools +from run_agent import AIAgent +from model_tools import 
get_tool_definitions, check_toolset_requirements +from tools.rl_training_tool import check_rl_api_keys, get_missing_keys, rl_health_check + + +# ============================================================================ +# RL-Specific Configuration +# ============================================================================ + +# Extended timeouts for long-running RL operations +RL_MAX_ITERATIONS = 200 # Allow many more iterations for long workflows + +# RL-focused system prompt +RL_SYSTEM_PROMPT = """You are an automated post-training engineer specializing in reinforcement learning for language models. + +## Your Capabilities + +You have access to RL training tools for running reinforcement learning on models through Tinker-Atropos: + +1. **DISCOVER**: Use `rl_list_environments` to see available RL environments +2. **INSPECT**: Read environment files to understand how they work (verifiers, data loading, rewards) +3. **INSPECT DATA**: Use terminal to explore HuggingFace datasets and understand their format +4. **CREATE**: Copy existing environments as templates, modify for your needs +5. **CONFIGURE**: Use `rl_select_environment` and `rl_edit_config` to set up training +6. **TEST**: Always use `rl_test_inference` before full training to validate your setup +7. **TRAIN**: Use `rl_start_training` to begin, `rl_check_status` to monitor +8. **EVALUATE**: Use `rl_get_results` and analyze WandB metrics to assess performance + +## Environment Files + +Environment files are located in: `tinker-atropos/tinker_atropos/environments/` + +Study existing environments to learn patterns. Look for: +- `load_dataset()` calls - how data is loaded +- `score_answer()` / `score()` - verification logic +- `get_next_item()` - prompt formatting +- `system_prompt` - instruction format +- `config_init()` - default configuration + +## Creating New Environments + +To create a new environment: +1. Read an existing environment file (e.g., gsm8k_tinker.py) +2. 
Use terminal to explore the target dataset format +3. Copy the environment file as a template +4. Modify the dataset loading, prompt formatting, and verifier logic +5. Test with `rl_test_inference` before training + +## Important Guidelines + +- **Always test before training**: Training runs take hours - verify everything works first +- **Monitor metrics**: Check WandB for reward/mean and percent_correct +- **Status check intervals**: Wait at least 30 minutes between status checks +- **Early stopping**: Stop training early if metrics look bad or stagnant +- **Iterate quickly**: Start with small total_steps to validate, then scale up + +## Available Toolsets + +You have access to: +- **RL tools**: Environment discovery, config management, training, testing +- **Terminal**: Run commands, inspect files, explore datasets +- **Web**: Search for information, documentation, papers +- **File tools**: Read and modify code files + +When asked to train a model, follow this workflow: +1. List available environments +2. Select and configure the appropriate environment +3. Test with sample prompts +4. Start training with conservative settings +5. 
Monitor progress and adjust as needed +""" + +# Toolsets to enable for RL workflows +RL_TOOLSETS = ["base", "terminal", "web", "rl"] + + +# ============================================================================ +# Helper Functions +# ============================================================================ + +def check_requirements(): + """Check that all required environment variables and services are available.""" + errors = [] + + # Check API keys + if not os.getenv("OPENROUTER_API_KEY"): + errors.append("OPENROUTER_API_KEY not set - required for agent") + + missing_rl_keys = get_missing_keys() + if missing_rl_keys: + errors.append(f"Missing RL API keys: {', '.join(missing_rl_keys)}") + + if errors: + print("āŒ Missing requirements:") + for error in errors: + print(f" - {error}") + print("\nPlease set these environment variables in your .env file or shell.") + return False + + return True + + +async def check_rl_server(): + """Check if the RL API server is running.""" + try: + result = await rl_health_check() + import json + data = json.loads(result) + if "error" in data: + return False, data["error"] + return True, data + except Exception as e: + return False, str(e) + + +def list_environments_sync(): + """List available environments (synchronous wrapper).""" + from tools.rl_training_tool import rl_list_environments + import json + + async def _list(): + result = await rl_list_environments() + return json.loads(result) + + return asyncio.run(_list()) + + +# ============================================================================ +# Main CLI +# ============================================================================ + +def main( + task: str = None, + model: str = "anthropic/claude-sonnet-4-20250514", + api_key: str = None, + base_url: str = "https://openrouter.ai/api/v1", + max_iterations: int = RL_MAX_ITERATIONS, + interactive: bool = False, + list_environments: bool = False, + check_server: bool = False, + verbose: bool = False, + 
save_trajectories: bool = True, +): + """ + RL Training CLI - Dedicated runner for RL training workflows. + + Args: + task: The training task/goal (e.g., "Train a model on GSM8k for math") + model: Model to use for the agent (default: claude-sonnet-4) + api_key: OpenRouter API key (uses OPENROUTER_API_KEY env var if not provided) + base_url: API base URL (default: OpenRouter) + max_iterations: Maximum agent iterations (default: 200 for long workflows) + interactive: Run in interactive mode (multiple conversations) + list_environments: Just list available RL environments and exit + check_server: Check if RL API server is running and exit + verbose: Enable verbose logging + save_trajectories: Save conversation trajectories (default: True for RL) + + Examples: + # Train on a specific environment + python rl_cli.py "Train a model on GSM8k math problems" + + # Interactive mode + python rl_cli.py --interactive + + # List available environments + python rl_cli.py --list-environments + + # Check server status + python rl_cli.py --check-server + """ + print("šŸŽÆ RL Training Agent") + print("=" * 60) + + # Handle server check + if check_server: + print("\nšŸ” Checking RL API server...") + ok, result = asyncio.run(check_rl_server()) + if ok: + print("āœ… RL API server is running") + print(f" Environments discovered: {result.get('environments_discovered', 'unknown')}") + print(f" Current environment: {result.get('current_environment', 'none')}") + print(f" Active runs: {result.get('active_runs', 0)}") + else: + print(f"āŒ RL API server not accessible: {result}") + print("\nTo start the server:") + print(" cd tinker-atropos && uvicorn rl_api_server:app --port 8080") + return + + # Handle environment listing + if list_environments: + print("\nšŸ“‹ Available RL Environments:") + print("-" * 40) + try: + data = list_environments_sync() + if "error" in data: + print(f"āŒ Error: {data['error']}") + return + + envs = data.get("environments", []) + if not envs: + print("No 
environments found.") + print("\nMake sure the RL API server is running:") + print(" cd tinker-atropos && uvicorn rl_api_server:app --port 8080") + return + + for env in envs: + print(f"\n šŸ“¦ {env['name']}") + print(f" Class: {env['class_name']}") + print(f" Path: {env['file_path']}") + if env.get('description'): + desc = env['description'][:100] + "..." if len(env.get('description', '')) > 100 else env.get('description', '') + print(f" Description: {desc}") + + print(f"\nšŸ“Š Total: {len(envs)} environments") + print("\nUse `rl_select_environment(name)` to select an environment for training.") + except Exception as e: + print(f"āŒ Error listing environments: {e}") + print("\nMake sure the RL API server is running:") + print(" cd tinker-atropos && uvicorn rl_api_server:app --port 8080") + return + + # Check requirements + if not check_requirements(): + sys.exit(1) + + # Set default task if none provided + if not task and not interactive: + print("\nāš ļø No task provided. Use --interactive for interactive mode or provide a task.") + print("\nExamples:") + print(' python rl_cli.py "Train a model on GSM8k math problems"') + print(' python rl_cli.py "Create an RL environment for code generation"') + print(' python rl_cli.py --interactive') + return + + # Get API key + api_key = api_key or os.getenv("OPENROUTER_API_KEY") + if not api_key: + print("āŒ No API key provided. 
Set OPENROUTER_API_KEY or pass --api-key") + sys.exit(1) + + print(f"\nšŸ¤– Model: {model}") + print(f"šŸ”§ Max iterations: {max_iterations}") + print(f"šŸ“ Toolsets: {', '.join(RL_TOOLSETS)}") + print("=" * 60) + + # Create agent with RL configuration + agent = AIAgent( + base_url=base_url, + api_key=api_key, + model=model, + max_iterations=max_iterations, + enabled_toolsets=RL_TOOLSETS, + save_trajectories=save_trajectories, + verbose_logging=verbose, + quiet_mode=False, + ephemeral_system_prompt=RL_SYSTEM_PROMPT, + ) + + if interactive: + # Interactive mode - multiple conversations + print("\nšŸ”„ Interactive RL Training Mode") + print("Type 'quit' or 'exit' to end the session.") + print("Type 'status' to check active training runs.") + print("-" * 40) + + while True: + try: + user_input = input("\nšŸŽÆ RL Task> ").strip() + + if not user_input: + continue + + if user_input.lower() in ('quit', 'exit', 'q'): + print("\nšŸ‘‹ Goodbye!") + break + + if user_input.lower() == 'status': + # Quick status check + from tools.rl_training_tool import rl_list_runs + import json + result = asyncio.run(rl_list_runs()) + runs = json.loads(result) + if isinstance(runs, list) and runs: + print("\nšŸ“Š Active Runs:") + for run in runs: + print(f" - {run['run_id']}: {run['environment']} ({run['status']})") + else: + print("\nNo active runs.") + continue + + # Run the agent + print("\n" + "=" * 60) + response = agent.run_conversation(user_input) + print("\n" + "=" * 60) + + except KeyboardInterrupt: + print("\n\nšŸ‘‹ Interrupted. 
Goodbye!") + break + except Exception as e: + print(f"\nāŒ Error: {e}") + if verbose: + import traceback + traceback.print_exc() + else: + # Single task mode + print(f"\nšŸ“ Task: {task}") + print("-" * 40) + + try: + response = agent.run_conversation(task) + print("\n" + "=" * 60) + print("āœ… Task completed") + except KeyboardInterrupt: + print("\n\nāš ļø Interrupted by user") + except Exception as e: + print(f"\nāŒ Error: {e}") + if verbose: + import traceback + traceback.print_exc() + sys.exit(1) + + +if __name__ == "__main__": + fire.Fire(main) diff --git a/tools/__init__.py b/tools/__init__.py index 3365dab4..dd8bb4da 100644 --- a/tools/__init__.py +++ b/tools/__init__.py @@ -95,6 +95,23 @@ from .cronjob_tools import ( REMOVE_CRONJOB_SCHEMA ) +# RL Training tools (Tinker-Atropos) +from .rl_training_tool import ( + rl_list_environments, + rl_select_environment, + rl_get_current_config, + rl_edit_config, + rl_start_training, + rl_check_status, + rl_stop_training, + rl_get_results, + rl_test_inference, + rl_list_runs, + rl_health_check, + check_rl_api_keys, + get_missing_keys, +) + __all__ = [ # Web tools 'web_search_tool', @@ -152,5 +169,19 @@ __all__ = [ 'SCHEDULE_CRONJOB_SCHEMA', 'LIST_CRONJOBS_SCHEMA', 'REMOVE_CRONJOB_SCHEMA', + # RL Training tools + 'rl_list_environments', + 'rl_select_environment', + 'rl_get_current_config', + 'rl_edit_config', + 'rl_start_training', + 'rl_check_status', + 'rl_stop_training', + 'rl_get_results', + 'rl_test_inference', + 'rl_list_runs', + 'rl_health_check', + 'check_rl_api_keys', + 'get_missing_keys', ] diff --git a/tools/rl_training_tool.py b/tools/rl_training_tool.py new file mode 100644 index 00000000..1b7401c1 --- /dev/null +++ b/tools/rl_training_tool.py @@ -0,0 +1,436 @@ +#!/usr/bin/env python3 +""" +RL Training Tools Module + +This module provides tools for running RL training through Tinker-Atropos. 
+Communicates with the RL API server (rl_api_server.py) to manage: +- Environment discovery and selection +- Configuration management +- Training run lifecycle +- WandB metrics monitoring +- Inference-only testing + +Required environment variables: +- TINKER_API_KEY: API key for Tinker service +- WANDB_API_KEY: API key for Weights & Biases metrics + +Optional environment variables: +- RL_API_URL: URL of the RL API server (default: http://localhost:8080) +- WANDB_ENTITY: WandB entity/team name +- WANDB_PROJECT: Default WandB project name + +Usage: + from tools.rl_training_tool import ( + rl_list_environments, + rl_select_environment, + rl_get_current_config, + rl_edit_config, + rl_start_training, + rl_check_status, + rl_stop_training, + rl_get_results, + rl_test_inference, + ) +""" + +import json +import os +import time +from typing import Any, Dict, List, Optional + +import aiohttp + +# ============================================================================ +# Configuration +# ============================================================================ + +# Default RL API server URL (can be overridden via environment variable) +RL_API_URL = os.getenv("RL_API_URL", "http://localhost:8080") + +# Rate limiting for status checks (30 minutes in seconds) +MIN_STATUS_CHECK_INTERVAL = 30 * 60 +_last_status_check: Dict[str, float] = {} + + +# ============================================================================ +# Helper Functions +# ============================================================================ + +async def _make_request( + method: str, + endpoint: str, + data: Optional[Dict] = None, + timeout: int = 30, +) -> Dict[str, Any]: + """Make an HTTP request to the RL API server.""" + url = f"{RL_API_URL}{endpoint}" + + async with aiohttp.ClientSession() as session: + try: + if method == "GET": + async with session.get(url, timeout=timeout) as response: + if response.status == 200: + return await response.json() + else: + error_text = await 
response.text() + return {"error": f"HTTP {response.status}: {error_text}"} + elif method == "POST": + async with session.post(url, json=data, timeout=timeout) as response: + if response.status == 200: + return await response.json() + else: + error_text = await response.text() + return {"error": f"HTTP {response.status}: {error_text}"} + except aiohttp.ClientConnectorError: + return { + "error": f"Cannot connect to RL API server at {RL_API_URL}. " + "Make sure the server is running: " + "cd tinker-atropos && uvicorn rl_api_server:app --port 8080" + } + except Exception as e: + return {"error": f"Request failed: {str(e)}"} + + +# ============================================================================ +# Environment Discovery Tools +# ============================================================================ + +async def rl_list_environments() -> str: + """ + List all available RL environments. + + Scans tinker-atropos/tinker_atropos/environments/ for Python files + containing classes that inherit from BaseEnv. + + Returns information about each environment including: + - name: Environment identifier + - class_name: Python class name + - file_path: Path to the environment file + - description: Brief description if available + + TIP: To create or modify RL environments: + 1. Use terminal/file tools to inspect existing environments + 2. Study how they load datasets, define verifiers, and structure rewards + 3. Inspect HuggingFace datasets to understand data formats + 4. Copy an existing environment as a template + 5. 
Test with rl_test_inference before running full training + + Returns: + JSON string with list of environments or error message + """ + result = await _make_request("GET", "/environments") + + if "error" in result: + return json.dumps(result, indent=2) + + # Add helpful tips to the response + response = { + "environments": result, + "count": len(result), + "tips": [ + "Use rl_select_environment(name) to select an environment", + "Read the file_path with file tools to understand how each environment works", + "Look for load_dataset(), score_answer(), get_next_item() methods", + ] + } + + return json.dumps(response, indent=2) + + +async def rl_select_environment(name: str) -> str: + """ + Select an RL environment for training. + + This loads the environment's default configuration into the config state. + After selecting, use rl_get_current_config() to see the configuration + and rl_edit_config() to modify specific fields. + + Args: + name: Name of the environment to select (from rl_list_environments) + + Returns: + JSON string with selection result, file path, and current config + + TIP: Read the returned file_path to understand how the environment works: + - How it loads data (load_dataset calls) + - How it verifies answers (score_answer method) + - What prompts it uses (system_prompt, get_next_item) + """ + result = await _make_request("POST", f"/environments/{name}/select") + return json.dumps(result, indent=2) + + +# ============================================================================ +# Configuration Tools +# ============================================================================ + +async def rl_get_current_config() -> str: + """ + Get the current environment configuration. + + Returns only the fields that are safe to modify. Other fields + (tokenizer_name, rollout_server_url, etc.) are fixed by the system. 
+ + Available fields: + - group_size: Rollouts per prompt (4-16 typical) + - max_token_length: Max generation tokens (2048-16384) + - total_steps: Training steps (50-2000) + - steps_per_eval: Steps between evaluations + - use_wandb: Enable WandB logging + - wandb_name: WandB run name prefix + - max_num_workers: Concurrent workers (-1 = auto) + + Returns: + JSON string with current config fields and their values + """ + result = await _make_request("GET", "/config") + return json.dumps(result, indent=2) + + +async def rl_edit_config(field: str, value: Any) -> str: + """ + Update a configuration field. + + Only exposed fields can be modified. Validates field name and type. + + Args: + field: Name of the field to update (e.g., "group_size", "total_steps") + value: New value for the field + + Valid fields: + - group_size (int): Rollouts per prompt + - max_token_length (int): Max generation tokens + - total_steps (int): Training steps + - steps_per_eval (int): Eval frequency + - use_wandb (bool): Enable logging + - wandb_name (str): Run name prefix + - max_num_workers (int): Workers count + + Returns: + JSON string with updated config or error message + """ + result = await _make_request("POST", "/config", {"field": field, "value": value}) + return json.dumps(result, indent=2) + + +# ============================================================================ +# Training Management Tools +# ============================================================================ + +async def rl_start_training( + wandb_project: str = "rl-training", + lora_rank: int = 32, + learning_rate: float = 4e-5, +) -> str: + """ + Start a new RL training run with the current environment and config. + + Requires an environment to be selected first using rl_select_environment(). + + WARNING: Training runs can take hours to days. Use rl_check_status() to + monitor progress (recommended: check every 30 minutes at most). 
+ + Args: + wandb_project: WandB project name for logging + lora_rank: LoRA rank for training (default: 32) + learning_rate: Learning rate (default: 4e-5) + + Returns: + JSON string with run_id and initial status + + TIP: Before starting training: + 1. Test with rl_test_inference() to verify the environment works + 2. Start with fewer total_steps to validate the setup + 3. Monitor WandB metrics for reward/mean and percent_correct + """ + result = await _make_request("POST", "/runs", { + "wandb_project": wandb_project, + "lora_rank": lora_rank, + "learning_rate": learning_rate, + }) + return json.dumps(result, indent=2) + + +async def rl_check_status(run_id: str) -> str: + """ + Get status and metrics for a training run. + + RATE LIMITED: For long-running training, this function enforces a + minimum 30-minute interval between checks for the same run_id. + + Fetches latest metrics from WandB if available: + - step: Current training step + - state: Run state (running, finished, crashed) + - reward_mean: Average reward across batches + - loss: Training loss + - percent_correct: Training accuracy + - eval_percent_correct: Evaluation accuracy + + Args: + run_id: The run ID returned by rl_start_training() + + Returns: + JSON string with run status and metrics, or rate limit message + """ + global _last_status_check + + # Check rate limiting + now = time.time() + if run_id in _last_status_check: + elapsed = now - _last_status_check[run_id] + if elapsed < MIN_STATUS_CHECK_INTERVAL: + remaining = MIN_STATUS_CHECK_INTERVAL - elapsed + return json.dumps({ + "rate_limited": True, + "run_id": run_id, + "message": f"Rate limited. Next check available in {remaining/60:.0f} minutes.", + "next_check_in_seconds": remaining, + }, indent=2) + + _last_status_check[run_id] = now + result = await _make_request("GET", f"/runs/{run_id}") + return json.dumps(result, indent=2) + + +async def rl_stop_training(run_id: str) -> str: + """ + Stop a running training job. 
+ + Use this if: + - Metrics look bad or training is stagnant + - You want to try different settings + - You need to free up resources + + Args: + run_id: The run ID to stop + + Returns: + JSON string with stop confirmation + """ + result = await _make_request("POST", f"/runs/{run_id}/stop") + return json.dumps(result, indent=2) + + +async def rl_get_results(run_id: str) -> str: + """ + Get final results and metrics for a completed training run. + + Returns: + - Final metrics (reward, loss, accuracy) + - WandB run URL for detailed analysis + - Path to trained weights (tinker:// URL) + + Args: + run_id: The run ID to get results for + + Returns: + JSON string with final results and weights path + """ + result = await _make_request("GET", f"/runs/{run_id}/metrics") + return json.dumps(result, indent=2) + + +# ============================================================================ +# Inference Testing Tools +# ============================================================================ + +async def rl_test_inference( + prompts: List[str], + max_tokens: int = 256, + temperature: float = 1.0, +) -> str: + """ + Test inference + verifier on sample prompts WITHOUT full training. + + Use this to validate environments before committing to long training runs. + Tests: + - Data loading and formatting + - Model inference through Tinker + - Verifier/reward function logic + + NOTE: This still requires the RL API server to be running with + Tinker access for the Sample() method. + + Args: + prompts: List of test prompts to run through the environment + max_tokens: Maximum tokens to generate per prompt + temperature: Sampling temperature + + Returns: + JSON string with responses and verifier scores for each prompt + + TIP: Include prompts with known correct/incorrect answers to verify + the reward function is working correctly. 
+ """ + result = await _make_request("POST", "/test/inference", { + "prompts": prompts, + "max_tokens": max_tokens, + "temperature": temperature, + }) + return json.dumps(result, indent=2) + + +# ============================================================================ +# Utility Tools +# ============================================================================ + +async def rl_list_runs() -> str: + """ + List all training runs (active and completed). + + Returns: + JSON string with list of runs and their status + """ + result = await _make_request("GET", "/runs") + return json.dumps(result, indent=2) + + +# ============================================================================ +# Requirements Check +# ============================================================================ + +def check_rl_api_keys() -> bool: + """ + Check if required API keys are available in environment variables. + + Required: + - TINKER_API_KEY: For Tinker training service + - WANDB_API_KEY: For metrics logging and fetching + + Returns: + bool: True if all required keys are set, False otherwise + """ + tinker_key = os.getenv("TINKER_API_KEY") + wandb_key = os.getenv("WANDB_API_KEY") + + return bool(tinker_key) and bool(wandb_key) + + +def get_missing_keys() -> List[str]: + """ + Get list of missing required API keys. + + Returns: + List of missing key names + """ + missing = [] + if not os.getenv("TINKER_API_KEY"): + missing.append("TINKER_API_KEY") + if not os.getenv("WANDB_API_KEY"): + missing.append("WANDB_API_KEY") + return missing + + +# ============================================================================ +# Debug/Status +# ============================================================================ + +async def rl_health_check() -> str: + """ + Check if the RL API server is running and accessible. 
+ + Returns: + JSON string with server health status + """ + result = await _make_request("GET", "/health") + return json.dumps(result, indent=2) diff --git a/toolsets.py b/toolsets.py index 5d08731e..e4644251 100644 --- a/toolsets.py +++ b/toolsets.py @@ -90,6 +90,18 @@ TOOLSETS = { "includes": [] }, + "rl": { + "description": "RL training tools for running reinforcement learning on Tinker-Atropos", + "tools": [ + "rl_list_environments", "rl_select_environment", + "rl_get_current_config", "rl_edit_config", + "rl_start_training", "rl_check_status", + "rl_stop_training", "rl_get_results", + "rl_test_inference", "rl_list_runs" + ], + "includes": [] + }, + # Scenario-specific toolsets "debugging": { From f6574978de39c6ccae8a06d13ddabbb2c72c9ce1 Mon Sep 17 00:00:00 2001 From: teknium1 Date: Wed, 4 Feb 2026 09:36:51 -0800 Subject: [PATCH 2/6] Add RL training configuration and tools - Updated `.env.example` to include Tinker and WandB API keys for reinforcement learning training. - Enhanced `model_tools.py` to clarify configuration options and streamline the RL training process. - Expanded `README.md` with detailed instructions for setting up RL training using Tinker and WandB. - Modified `hermes_cli` files to integrate RL training tools and ensure proper configuration checks. - Improved `rl_training_tool.py` to reflect changes in training parameters and configuration management. 
--- .env.example | 18 +++++++++++ README.md | 56 ++++++++++++++++++++++++++++++++++ hermes_cli/config.py | 14 +++++++++ hermes_cli/setup.py | 49 ++++++++++++++++++++++++++++++ hermes_cli/status.py | 2 ++ model_tools.py | 32 ++++---------------- tools/rl_training_tool.py | 63 ++++++++++++++++----------------------- 7 files changed, 169 insertions(+), 65 deletions(-) diff --git a/.env.example b/.env.example index 98c5ea19..85ecf09d 100644 --- a/.env.example +++ b/.env.example @@ -165,3 +165,21 @@ IMAGE_TOOLS_DEBUG=false # CONTEXT_COMPRESSION_ENABLED=true # Enable auto-compression (default: true) # CONTEXT_COMPRESSION_THRESHOLD=0.85 # Compress at 85% of context limit # CONTEXT_COMPRESSION_MODEL=google/gemini-2.0-flash-001 # Fast model for summaries + +# ============================================================================= +# RL TRAINING (Tinker + Atropos) +# ============================================================================= +# Run reinforcement learning training on language models using the Tinker API. +# Requires the rl-server to be running (from tinker-atropos package). 
+ +# Tinker API Key - RL training service +# Get at: https://tinker-console.thinkingmachines.ai/keys +TINKER_API_KEY= + +# Weights & Biases API Key - Experiment tracking and metrics +# Get at: https://wandb.ai/authorize +WANDB_API_KEY= + +# RL API Server URL (default: http://localhost:8080) +# Change if running the rl-server on a different host/port +# RL_API_URL=http://localhost:8080 diff --git a/README.md b/README.md index 8a999cb1..f49ae26a 100644 --- a/README.md +++ b/README.md @@ -74,6 +74,7 @@ You need at least one LLM provider: | Web scraping | [Firecrawl](https://firecrawl.dev/) | `FIRECRAWL_API_KEY` | | Browser automation | [Browserbase](https://browserbase.com/) | `BROWSERBASE_API_KEY`, `BROWSERBASE_PROJECT_ID` | | Image generation | [FAL](https://fal.ai/) | `FAL_KEY` | +| RL Training | [Tinker](https://tinker-console.thinkingmachines.ai/) + [WandB](https://wandb.ai/) | `TINKER_API_KEY`, `WANDB_API_KEY` | | Messaging | Telegram, Discord | `TELEGRAM_BOT_TOKEN`, `DISCORD_BOT_TOKEN` | --- @@ -270,6 +271,61 @@ When enabled, you'll see messages like: See [docs/messaging.md](docs/messaging.md) for WhatsApp and advanced setup. +### šŸ¤– RL Training (Tinker + Atropos) + +Train language models with reinforcement learning using the Tinker API and Atropos framework. + +#### Requirements + +1. **API Keys:** Add to `~/.hermes/.env`: +```bash +TINKER_API_KEY=your-tinker-key # Get from https://tinker-console.thinkingmachines.ai/keys +WANDB_API_KEY=your-wandb-key # Get from https://wandb.ai/authorize +``` + +2. **Install tinker-atropos:** (in a separate directory) +```bash +cd ~/tinker-atropos +pip install -e . +``` + +3. **Start the RL API server:** +```bash +rl-server # Runs on port 8080 by default +``` + +#### Using RL Tools + +The agent can now use RL training tools: + +``` +You: Start training on GSM8k with group_size=16 + +Agent: I'll set up an RL training run on the GSM8k environment... 
+[Uses rl_list_environments, rl_select_environment, rl_edit_config, rl_start_training] +``` + +#### Available RL Tools + +| Tool | Description | +|------|-------------| +| `rl_list_environments` | List available RL environments | +| `rl_select_environment` | Select an environment for training | +| `rl_get_current_config` | View all configurable options | +| `rl_edit_config` | Change a configuration value | +| `rl_start_training` | Start a training run | +| `rl_check_status` | Check training progress | +| `rl_stop_training` | Stop a running training | +| `rl_get_results` | Fetch WandB metrics | + +#### Dedicated RL CLI + +For extended RL workflows with longer timeouts: + +```bash +python rl_cli.py --model "anthropic/claude-sonnet-4-20250514" +``` + ### ā° Scheduled Tasks (Cron) Schedule tasks to run automatically: diff --git a/hermes_cli/config.py b/hermes_cli/config.py index a0d98b6a..82ce6ae7 100644 --- a/hermes_cli/config.py +++ b/hermes_cli/config.py @@ -151,6 +151,20 @@ OPTIONAL_ENV_VARS = { "tools": ["image_generate"], "password": True, }, + "TINKER_API_KEY": { + "description": "Tinker API key for RL training", + "prompt": "Tinker API key", + "url": "https://tinker-console.thinkingmachines.ai/keys", + "tools": ["rl_start_training", "rl_check_status", "rl_stop_training"], + "password": True, + }, + "WANDB_API_KEY": { + "description": "Weights & Biases API key for experiment tracking", + "prompt": "WandB API key", + "url": "https://wandb.ai/authorize", + "tools": ["rl_get_results", "rl_check_status"], + "password": True, + }, "OPENAI_BASE_URL": { "description": "Custom OpenAI-compatible API endpoint URL", "prompt": "API base URL (e.g., https://api.example.com/v1)", diff --git a/hermes_cli/setup.py b/hermes_cli/setup.py index 06668d4e..83f42730 100644 --- a/hermes_cli/setup.py +++ b/hermes_cli/setup.py @@ -186,6 +186,14 @@ def _print_setup_summary(config: dict, hermes_home): else: tool_status.append(("Image Generation", False, "FAL_KEY")) + # Tinker + WandB (RL 
training) + if get_env_value('TINKER_API_KEY') and get_env_value('WANDB_API_KEY'): + tool_status.append(("RL Training (Tinker)", True, None)) + elif get_env_value('TINKER_API_KEY'): + tool_status.append(("RL Training (Tinker)", False, "WANDB_API_KEY")) + else: + tool_status.append(("RL Training (Tinker)", False, "TINKER_API_KEY")) + # Terminal (always available if system deps met) tool_status.append(("Terminal/Commands", True, None)) @@ -932,6 +940,47 @@ def run_setup_wizard(args): if api_key: save_env_value("FAL_KEY", api_key) print_success(" Configured āœ“") + print() + + # Tinker + WandB - RL Training + print_info("─" * 50) + print(color(" RL Training (Tinker + WandB)", Colors.CYAN)) + print_info(" Enables: rl_start_training, rl_check_status, rl_get_results tools") + print_info(" Use case: Run reinforcement learning training via Tinker API") + tinker_configured = get_env_value('TINKER_API_KEY') + wandb_configured = get_env_value('WANDB_API_KEY') + + if tinker_configured and wandb_configured: + print_success(" Status: Configured āœ“") + if prompt_yes_no(" Update RL training credentials?", False): + api_key = prompt(" Tinker API key", password=True) + if api_key: + save_env_value("TINKER_API_KEY", api_key) + wandb_key = prompt(" WandB API key", password=True) + if wandb_key: + save_env_value("WANDB_API_KEY", wandb_key) + print_success(" Updated") + else: + if tinker_configured: + print_warning(" Status: Tinker configured, WandB missing") + elif wandb_configured: + print_warning(" Status: WandB configured, Tinker missing") + else: + print_warning(" Status: Not configured (tools will be disabled)") + + if prompt_yes_no(" Set up RL Training?", False): + print_info(" Get Tinker key at: https://tinker-console.thinkingmachines.ai/keys") + print_info(" Get WandB key at: https://wandb.ai/authorize") + api_key = prompt(" Tinker API key", password=True) + if api_key: + save_env_value("TINKER_API_KEY", api_key) + wandb_key = prompt(" WandB API key", password=True) + if 
wandb_key: + save_env_value("WANDB_API_KEY", wandb_key) + if api_key and wandb_key: + print_success(" Configured āœ“") + else: + print_warning(" Partially configured (both keys required)") # ========================================================================= # Save config and show summary diff --git a/hermes_cli/status.py b/hermes_cli/status.py index 2d24bb50..bbbdc2af 100644 --- a/hermes_cli/status.py +++ b/hermes_cli/status.py @@ -74,6 +74,8 @@ def show_status(args): "Firecrawl": "FIRECRAWL_API_KEY", "Browserbase": "BROWSERBASE_API_KEY", "FAL": "FAL_KEY", + "Tinker": "TINKER_API_KEY", + "WandB": "WANDB_API_KEY", } for name, env_var in keys.items(): diff --git a/model_tools.py b/model_tools.py index ebabaf56..d84c3296 100644 --- a/model_tools.py +++ b/model_tools.py @@ -554,13 +554,13 @@ def get_rl_tool_definitions() -> List[Dict[str, Any]]: "type": "function", "function": { "name": "rl_edit_config", - "description": "Update a configuration field. Valid fields: group_size (int), max_token_length (int), total_steps (int), steps_per_eval (int), use_wandb (bool), wandb_name (str), max_num_workers (int).", + "description": "Update a configuration field. Use rl_get_current_config() first to see all available fields for the selected environment. Each environment has different configurable options. Infrastructure settings (tokenizer, URLs, lora_rank, learning_rate) are locked.", "parameters": { "type": "object", "properties": { "field": { "type": "string", - "description": "Name of the field to update" + "description": "Name of the field to update (get available fields from rl_get_current_config)" }, "value": { "description": "New value for the field" @@ -574,26 +574,10 @@ def get_rl_tool_definitions() -> List[Dict[str, Any]]: "type": "function", "function": { "name": "rl_start_training", - "description": "Start a new RL training run. WARNING: Training can take hours. Use rl_check_status() to monitor (30-minute intervals recommended). 
Test with rl_test_inference() first!", + "description": "Start a new RL training run with the current environment and config. Most training parameters (lora_rank, learning_rate, etc.) are fixed. Use rl_edit_config() to set group_size, batch_size, wandb_project before starting. WARNING: Training takes hours. Test with rl_test_inference() first!", "parameters": { "type": "object", - "properties": { - "wandb_project": { - "type": "string", - "description": "WandB project name for logging", - "default": "rl-training" - }, - "lora_rank": { - "type": "integer", - "description": "LoRA rank for training", - "default": 32 - }, - "learning_rate": { - "type": "number", - "description": "Learning rate", - "default": 4e-5 - } - }, + "properties": {}, "required": [] } } @@ -1324,13 +1308,7 @@ def handle_rl_function_call( ) elif function_name == "rl_start_training": - return loop.run_until_complete( - rl_start_training( - wandb_project=function_args.get("wandb_project", "rl-training"), - lora_rank=function_args.get("lora_rank", 32), - learning_rate=function_args.get("learning_rate", 4e-5) - ) - ) + return loop.run_until_complete(rl_start_training()) elif function_name == "rl_check_status": return loop.run_until_complete( diff --git a/tools/rl_training_tool.py b/tools/rl_training_tool.py index 1b7401c1..7c40bc72 100644 --- a/tools/rl_training_tool.py +++ b/tools/rl_training_tool.py @@ -168,20 +168,22 @@ async def rl_get_current_config() -> str: """ Get the current environment configuration. - Returns only the fields that are safe to modify. Other fields - (tokenizer_name, rollout_server_url, etc.) are fixed by the system. + Returns all configurable fields for the selected environment. + Each environment may have different configuration options. 
- Available fields: - - group_size: Rollouts per prompt (4-16 typical) - - max_token_length: Max generation tokens (2048-16384) - - total_steps: Training steps (50-2000) - - steps_per_eval: Steps between evaluations - - use_wandb: Enable WandB logging + Fields are divided into: + - configurable_fields: Can be changed with rl_edit_config() + - locked_fields: Infrastructure settings that cannot be changed + + Common configurable fields include: + - group_size: Rollouts per prompt + - batch_size: Training batch size - wandb_name: WandB run name prefix - - max_num_workers: Concurrent workers (-1 = auto) + - system_prompt: Model instructions + - And any environment-specific options Returns: - JSON string with current config fields and their values + JSON string with configurable and locked fields """ result = await _make_request("GET", "/config") return json.dumps(result, indent=2) @@ -191,21 +193,15 @@ async def rl_edit_config(field: str, value: Any) -> str: """ Update a configuration field. - Only exposed fields can be modified. Validates field name and type. + Use rl_get_current_config() first to see available fields for the + selected environment. Each environment has different options. + + Locked fields (infrastructure settings) cannot be changed. 
Args: - field: Name of the field to update (e.g., "group_size", "total_steps") + field: Name of the field to update (from rl_get_current_config) value: New value for the field - Valid fields: - - group_size (int): Rollouts per prompt - - max_token_length (int): Max generation tokens - - total_steps (int): Training steps - - steps_per_eval (int): Eval frequency - - use_wandb (bool): Enable logging - - wandb_name (str): Run name prefix - - max_num_workers (int): Workers count - Returns: JSON string with updated config or error message """ @@ -217,37 +213,28 @@ async def rl_edit_config(field: str, value: Any) -> str: # Training Management Tools # ============================================================================ -async def rl_start_training( - wandb_project: str = "rl-training", - lora_rank: int = 32, - learning_rate: float = 4e-5, -) -> str: +async def rl_start_training() -> str: """ Start a new RL training run with the current environment and config. Requires an environment to be selected first using rl_select_environment(). + Use rl_edit_config() to set group_size, batch_size, wandb_project before starting. - WARNING: Training runs can take hours to days. Use rl_check_status() to - monitor progress (recommended: check every 30 minutes at most). + Most training parameters are fixed (lora_rank=32, learning_rate=4e-5, etc.) + and cannot be changed. - Args: - wandb_project: WandB project name for logging - lora_rank: LoRA rank for training (default: 32) - learning_rate: Learning rate (default: 4e-5) + WARNING: Training runs take hours. Use rl_check_status() to monitor + progress (recommended: check every 30 minutes at most). Returns: JSON string with run_id and initial status TIP: Before starting training: 1. Test with rl_test_inference() to verify the environment works - 2. Start with fewer total_steps to validate the setup + 2. Configure group_size and batch_size appropriately 3. 
Monitor WandB metrics for reward/mean and percent_correct """ - result = await _make_request("POST", "/runs", { - "wandb_project": wandb_project, - "lora_rank": lora_rank, - "learning_rate": learning_rate, - }) + result = await _make_request("POST", "/runs", {}) return json.dumps(result, indent=2) From 12bbca95ecf4bbca5e3d4056526584ae3624e3c7 Mon Sep 17 00:00:00 2001 From: teknium1 Date: Wed, 4 Feb 2026 10:36:01 -0800 Subject: [PATCH 3/6] Add tinker-atropos submodule and update RL training tools - Added the tinker-atropos submodule for enhanced RL training capabilities. - Updated model_tools.py to reorder RL function definitions and improve descriptions. - Modified rl_cli.py to include checks for the tinker-atropos setup and provide user guidance. - Adjusted toolsets.py and __init__.py to reflect changes in RL function availability. - Enhanced rl_training_tool.py to manage training processes directly without a separate API server. --- .gitmodules | 3 + model_tools.py | 74 ++- rl_cli.py | 66 ++- tinker-atropos | 1 + tools/__init__.py | 6 +- tools/rl_training_tool.py | 1163 +++++++++++++++++++++++++++++++------ toolsets.py | 2 +- 7 files changed, 1059 insertions(+), 256 deletions(-) create mode 160000 tinker-atropos diff --git a/.gitmodules b/.gitmodules index f08f6745..6a494f4b 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +1,6 @@ [submodule "mini-swe-agent"] path = mini-swe-agent url = https://github.com/SWE-agent/mini-swe-agent +[submodule "tinker-atropos"] + path = tinker-atropos + url = https://github.com/nousresearch/tinker-atropos diff --git a/model_tools.py b/model_tools.py index d84c3296..847e56ef 100644 --- a/model_tools.py +++ b/model_tools.py @@ -49,9 +49,8 @@ from tools.rl_training_tool import ( rl_check_status, rl_stop_training, rl_get_results, - rl_test_inference, rl_list_runs, - rl_health_check, + rl_test_inference, check_rl_api_keys, ) # Cronjob management tools (CLI-only) @@ -153,7 +152,7 @@ TOOLSET_REQUIREMENTS = { "rl_get_current_config", 
"rl_edit_config", "rl_start_training", "rl_check_status", "rl_stop_training", "rl_get_results", - "rl_test_inference", "rl_list_runs", + "rl_list_runs", "rl_test_inference", ], }, } @@ -574,7 +573,7 @@ def get_rl_tool_definitions() -> List[Dict[str, Any]]: "type": "function", "function": { "name": "rl_start_training", - "description": "Start a new RL training run with the current environment and config. Most training parameters (lora_rank, learning_rate, etc.) are fixed. Use rl_edit_config() to set group_size, batch_size, wandb_project before starting. WARNING: Training takes hours. Test with rl_test_inference() first!", + "description": "Start a new RL training run with the current environment and config. Most training parameters (lora_rank, learning_rate, etc.) are fixed. Use rl_edit_config() to set group_size, batch_size, wandb_project before starting. WARNING: Training takes hours.", "parameters": { "type": "object", "properties": {}, @@ -636,39 +635,39 @@ def get_rl_tool_definitions() -> List[Dict[str, Any]]: { "type": "function", "function": { - "name": "rl_test_inference", - "description": "Test inference + verifier on sample prompts WITHOUT full training. Use to validate environments before committing to long training runs. 
Tests data loading, inference, and verifier logic.", + "name": "rl_list_runs", + "description": "List all training runs (active and completed) with their status.", "parameters": { "type": "object", - "properties": { - "prompts": { - "type": "array", - "items": {"type": "string"}, - "description": "List of test prompts to run through the environment" - }, - "max_tokens": { - "type": "integer", - "description": "Maximum tokens to generate per prompt", - "default": 256 - }, - "temperature": { - "type": "number", - "description": "Sampling temperature", - "default": 1.0 - } - }, - "required": ["prompts"] + "properties": {}, + "required": [] } } }, { "type": "function", "function": { - "name": "rl_list_runs", - "description": "List all training runs (active and completed) with their status.", + "name": "rl_test_inference", + "description": "Quick inference test for any environment. Runs a few steps of inference + scoring using OpenRouter. Default: 3 steps Ɨ 16 completions = 48 rollouts per model, testing 3 models = 144 total. Tests environment loading, prompt construction, inference parsing, and verifier logic. Use BEFORE training to catch issues.", "parameters": { "type": "object", - "properties": {}, + "properties": { + "num_steps": { + "type": "integer", + "description": "Number of steps to run (default: 3, recommended max for testing)", + "default": 3 + }, + "group_size": { + "type": "integer", + "description": "Completions per step (default: 16, like training)", + "default": 16 + }, + "models": { + "type": "array", + "items": {"type": "string"}, + "description": "Optional list of OpenRouter model IDs. 
Default: qwen/qwen3-8b, zhipu-ai/glm-4-flash, minimax/minimax-m1" + } + }, "required": [] } } @@ -731,7 +730,7 @@ def get_all_tool_names() -> List[str]: "rl_get_current_config", "rl_edit_config", "rl_start_training", "rl_check_status", "rl_stop_training", "rl_get_results", - "rl_test_inference", "rl_list_runs" + "rl_list_runs" ]) return tool_names @@ -782,7 +781,6 @@ def get_toolset_for_tool(tool_name: str) -> str: "rl_check_status": "rl_tools", "rl_stop_training": "rl_tools", "rl_get_results": "rl_tools", - "rl_test_inference": "rl_tools", "rl_list_runs": "rl_tools", } @@ -900,7 +898,7 @@ def get_tool_definitions( "rl_get_current_config", "rl_edit_config", "rl_start_training", "rl_check_status", "rl_stop_training", "rl_get_results", - "rl_test_inference", "rl_list_runs" + "rl_list_runs" ] } legacy_tools = legacy_map.get(toolset_name, []) @@ -952,7 +950,7 @@ def get_tool_definitions( "rl_get_current_config", "rl_edit_config", "rl_start_training", "rl_check_status", "rl_stop_training", "rl_get_results", - "rl_test_inference", "rl_list_runs" + "rl_list_runs" ] } legacy_tools = legacy_map.get(toolset_name, []) @@ -1325,18 +1323,18 @@ def handle_rl_function_call( rl_get_results(run_id=function_args.get("run_id", "")) ) + elif function_name == "rl_list_runs": + return loop.run_until_complete(rl_list_runs()) + elif function_name == "rl_test_inference": return loop.run_until_complete( rl_test_inference( - prompts=function_args.get("prompts", []), - max_tokens=function_args.get("max_tokens", 256), - temperature=function_args.get("temperature", 1.0) + num_steps=function_args.get("num_steps", 3), + group_size=function_args.get("group_size", 16), + models=function_args.get("models"), ) ) - elif function_name == "rl_list_runs": - return loop.run_until_complete(rl_list_runs()) - return json.dumps({"error": f"Unknown RL function: {function_name}"}, ensure_ascii=False) @@ -1409,7 +1407,7 @@ def handle_function_call( "rl_get_current_config", "rl_edit_config", "rl_start_training", 
"rl_check_status", "rl_stop_training", "rl_get_results", - "rl_test_inference", "rl_list_runs" + "rl_list_runs" ]: return handle_rl_function_call(function_name, function_args) diff --git a/rl_cli.py b/rl_cli.py index cd76c91d..fe0eecfd 100644 --- a/rl_cli.py +++ b/rl_cli.py @@ -16,7 +16,6 @@ Usage: Environment Variables: TINKER_API_KEY: API key for Tinker service (required) WANDB_API_KEY: API key for WandB metrics (required) - RL_API_URL: URL of RL API server (default: http://localhost:8080) OPENROUTER_API_KEY: API key for OpenRouter (required for agent) """ @@ -38,7 +37,7 @@ if env_path.exists(): # Import agent and tools from run_agent import AIAgent from model_tools import get_tool_definitions, check_toolset_requirements -from tools.rl_training_tool import check_rl_api_keys, get_missing_keys, rl_health_check +from tools.rl_training_tool import check_rl_api_keys, get_missing_keys # ============================================================================ @@ -138,17 +137,21 @@ def check_requirements(): return True -async def check_rl_server(): - """Check if the RL API server is running.""" - try: - result = await rl_health_check() - import json - data = json.loads(result) - if "error" in data: - return False, data["error"] - return True, data - except Exception as e: - return False, str(e) +def check_tinker_atropos(): + """Check if tinker-atropos submodule is properly set up.""" + tinker_path = Path(__file__).parent / "tinker-atropos" + + if not tinker_path.exists(): + return False, "tinker-atropos submodule not found. 
Run: git submodule update --init" + + envs_path = tinker_path / "tinker_atropos" / "environments" + if not envs_path.exists(): + return False, f"environments directory not found at {envs_path}" + + env_files = list(envs_path.glob("*.py")) + env_files = [f for f in env_files if not f.name.startswith("_")] + + return True, {"path": str(tinker_path), "environments_count": len(env_files)} def list_environments_sync(): @@ -210,19 +213,27 @@ def main( print("šŸŽÆ RL Training Agent") print("=" * 60) - # Handle server check + # Handle setup check if check_server: - print("\nšŸ” Checking RL API server...") - ok, result = asyncio.run(check_rl_server()) + print("\nšŸ” Checking tinker-atropos setup...") + ok, result = check_tinker_atropos() if ok: - print("āœ… RL API server is running") - print(f" Environments discovered: {result.get('environments_discovered', 'unknown')}") - print(f" Current environment: {result.get('current_environment', 'none')}") - print(f" Active runs: {result.get('active_runs', 0)}") + print("āœ… tinker-atropos submodule found") + print(f" Path: {result.get('path')}") + print(f" Environments found: {result.get('environments_count', 0)}") + + # Also check API keys + missing = get_missing_keys() + if missing: + print(f"\nāš ļø Missing API keys: {', '.join(missing)}") + print(" Add them to ~/.hermes/.env") + else: + print("āœ… API keys configured") else: - print(f"āŒ RL API server not accessible: {result}") - print("\nTo start the server:") - print(" cd tinker-atropos && uvicorn rl_api_server:app --port 8080") + print(f"āŒ tinker-atropos not set up: {result}") + print("\nTo set up:") + print(" git submodule update --init") + print(" pip install -e ./tinker-atropos") return # Handle environment listing @@ -238,8 +249,8 @@ def main( envs = data.get("environments", []) if not envs: print("No environments found.") - print("\nMake sure the RL API server is running:") - print(" cd tinker-atropos && uvicorn rl_api_server:app --port 8080") + print("\nMake sure 
tinker-atropos is set up:") + print(" git submodule update --init") return for env in envs: @@ -254,8 +265,9 @@ def main( print("\nUse `rl_select_environment(name)` to select an environment for training.") except Exception as e: print(f"āŒ Error listing environments: {e}") - print("\nMake sure the RL API server is running:") - print(" cd tinker-atropos && uvicorn rl_api_server:app --port 8080") + print("\nMake sure tinker-atropos is set up:") + print(" git submodule update --init") + print(" pip install -e ./tinker-atropos") return # Check requirements diff --git a/tinker-atropos b/tinker-atropos new file mode 160000 index 00000000..65f084ee --- /dev/null +++ b/tinker-atropos @@ -0,0 +1 @@ +Subproject commit 65f084ee8054a5d02aeac76e24ed60388511c82b diff --git a/tools/__init__.py b/tools/__init__.py index dd8bb4da..0b6bcdcc 100644 --- a/tools/__init__.py +++ b/tools/__init__.py @@ -105,9 +105,8 @@ from .rl_training_tool import ( rl_check_status, rl_stop_training, rl_get_results, - rl_test_inference, rl_list_runs, - rl_health_check, + rl_test_inference, check_rl_api_keys, get_missing_keys, ) @@ -178,9 +177,8 @@ __all__ = [ 'rl_check_status', 'rl_stop_training', 'rl_get_results', - 'rl_test_inference', 'rl_list_runs', - 'rl_health_check', + 'rl_test_inference', 'check_rl_api_keys', 'get_missing_keys', ] diff --git a/tools/rl_training_tool.py b/tools/rl_training_tool.py index 7c40bc72..3c257c4c 100644 --- a/tools/rl_training_tool.py +++ b/tools/rl_training_tool.py @@ -3,22 +3,18 @@ RL Training Tools Module This module provides tools for running RL training through Tinker-Atropos. -Communicates with the RL API server (rl_api_server.py) to manage: -- Environment discovery and selection -- Configuration management -- Training run lifecycle +Directly manages training processes without requiring a separate API server. 
+ +Features: +- Environment discovery (AST-based scanning for BaseEnv subclasses) +- Configuration management with locked infrastructure settings +- Training run lifecycle via subprocess management - WandB metrics monitoring -- Inference-only testing Required environment variables: - TINKER_API_KEY: API key for Tinker service - WANDB_API_KEY: API key for Weights & Biases metrics -Optional environment variables: -- RL_API_URL: URL of the RL API server (default: http://localhost:8080) -- WANDB_ENTITY: WandB entity/team name -- WANDB_PROJECT: Default WandB project name - Usage: from tools.rl_training_tool import ( rl_list_environments, @@ -29,66 +25,429 @@ Usage: rl_check_status, rl_stop_training, rl_get_results, - rl_test_inference, ) """ +import ast +import asyncio +import importlib.util import json import os +import subprocess +import sys import time +import uuid +import yaml +from dataclasses import dataclass, field +from pathlib import Path from typing import Any, Dict, List, Optional -import aiohttp - # ============================================================================ -# Configuration +# Path Configuration # ============================================================================ -# Default RL API server URL (can be overridden via environment variable) -RL_API_URL = os.getenv("RL_API_URL", "http://localhost:8080") +# Path to tinker-atropos submodule (relative to hermes-agent root) +HERMES_ROOT = Path(__file__).parent.parent +TINKER_ATROPOS_ROOT = HERMES_ROOT / "tinker-atropos" +ENVIRONMENTS_DIR = TINKER_ATROPOS_ROOT / "tinker_atropos" / "environments" +CONFIGS_DIR = TINKER_ATROPOS_ROOT / "configs" +LOGS_DIR = TINKER_ATROPOS_ROOT / "logs" -# Rate limiting for status checks (30 minutes in seconds) -MIN_STATUS_CHECK_INTERVAL = 30 * 60 +# Ensure logs directory exists +LOGS_DIR.mkdir(exist_ok=True) + + +# ============================================================================ +# Locked Configuration (Infrastructure Settings) +# 
============================================================================ + +# These fields cannot be changed by the model - they're tuned for our infrastructure +LOCKED_FIELDS = { + "env": { + "tokenizer_name": "Qwen/Qwen3-8B", + "rollout_server_url": "http://localhost:8000", + "use_wandb": True, + "max_token_length": 8192, + "max_num_workers": 2048, + "worker_timeout": 3600, + "total_steps": 2500, + "steps_per_eval": 25, + "max_batches_offpolicy": 3, + "inference_weight": 1.0, + "eval_limit_ratio": 0.1, + }, + "openai": [ + { + "model_name": "Qwen/Qwen3-8B", + "base_url": "http://localhost:8001/v1", + "api_key": "x", + "weight": 1.0, + "num_requests_for_eval": 256, + "timeout": 3600, + } + ], + "tinker": { + "lora_rank": 32, + "learning_rate": 0.00004, + "max_token_trainer_length": 9000, + "checkpoint_dir": "./temp/", + "save_checkpoint_interval": 25, + }, + "slurm": False, + "testing": False, +} + +LOCKED_FIELD_NAMES = set(LOCKED_FIELDS.get("env", {}).keys()) + + +# ============================================================================ +# State Management +# ============================================================================ + +@dataclass +class EnvironmentInfo: + """Information about a discovered environment.""" + name: str + class_name: str + file_path: str + description: str = "" + config_class: str = "BaseEnvConfig" + + +@dataclass +class RunState: + """State for a training run.""" + run_id: str + environment: str + config: Dict[str, Any] + status: str = "pending" # pending, starting, running, stopping, stopped, completed, failed + error_message: str = "" + wandb_project: str = "" + wandb_run_name: str = "" + start_time: float = 0.0 + # Process handles + api_process: Optional[subprocess.Popen] = None + trainer_process: Optional[subprocess.Popen] = None + env_process: Optional[subprocess.Popen] = None + + +# Global state +_environments: List[EnvironmentInfo] = [] +_current_env: Optional[str] = None +_current_config: Dict[str, Any] = {} 
+_env_config_cache: Dict[str, Dict[str, Dict[str, Any]]] = {} +_active_runs: Dict[str, RunState] = {} _last_status_check: Dict[str, float] = {} +# Rate limiting for status checks (30 minutes) +MIN_STATUS_CHECK_INTERVAL = 30 * 60 + # ============================================================================ -# Helper Functions +# Environment Discovery # ============================================================================ -async def _make_request( - method: str, - endpoint: str, - data: Optional[Dict] = None, - timeout: int = 30, -) -> Dict[str, Any]: - """Make an HTTP request to the RL API server.""" - url = f"{RL_API_URL}{endpoint}" +def _scan_environments() -> List[EnvironmentInfo]: + """ + Scan the environments directory for BaseEnv subclasses using AST. + """ + environments = [] - async with aiohttp.ClientSession() as session: + if not ENVIRONMENTS_DIR.exists(): + return environments + + for py_file in ENVIRONMENTS_DIR.glob("*.py"): + if py_file.name.startswith("_"): + continue + try: - if method == "GET": - async with session.get(url, timeout=timeout) as response: - if response.status == 200: - return await response.json() - else: - error_text = await response.text() - return {"error": f"HTTP {response.status}: {error_text}"} - elif method == "POST": - async with session.post(url, json=data, timeout=timeout) as response: - if response.status == 200: - return await response.json() - else: - error_text = await response.text() - return {"error": f"HTTP {response.status}: {error_text}"} - except aiohttp.ClientConnectorError: - return { - "error": f"Cannot connect to RL API server at {RL_API_URL}. 
" - "Make sure the server is running: " - "cd tinker-atropos && uvicorn rl_api_server:app --port 8080" - } + with open(py_file, "r") as f: + tree = ast.parse(f.read()) + + for node in ast.walk(tree): + if isinstance(node, ast.ClassDef): + # Check if class has BaseEnv as base + for base in node.bases: + base_name = "" + if isinstance(base, ast.Name): + base_name = base.id + elif isinstance(base, ast.Attribute): + base_name = base.attr + + if base_name == "BaseEnv": + # Extract name from class attribute if present + env_name = py_file.stem + description = "" + config_class = "BaseEnvConfig" + + for item in node.body: + if isinstance(item, ast.Assign): + for target in item.targets: + if isinstance(target, ast.Name): + if target.id == "name" and isinstance(item.value, ast.Constant): + env_name = item.value.value + elif target.id == "env_config_cls" and isinstance(item.value, ast.Name): + config_class = item.value.id + + # Get docstring + if isinstance(item, ast.Expr) and isinstance(item.value, ast.Constant): + if isinstance(item.value.value, str) and not description: + description = item.value.value.split("\n")[0].strip() + + environments.append(EnvironmentInfo( + name=env_name, + class_name=node.name, + file_path=str(py_file), + description=description or f"Environment from {py_file.name}", + config_class=config_class, + )) + break except Exception as e: - return {"error": f"Request failed: {str(e)}"} + print(f"Warning: Could not parse {py_file}: {e}") + + return environments + + +def _get_env_config_fields(env_file_path: str) -> Dict[str, Dict[str, Any]]: + """ + Dynamically import an environment and extract its config fields. 
+ """ + try: + # Load the environment module + spec = importlib.util.spec_from_file_location("env_module", env_file_path) + module = importlib.util.module_from_spec(spec) + sys.modules["env_module"] = module + spec.loader.exec_module(module) + + # Find the BaseEnv subclass + env_class = None + for name, obj in vars(module).items(): + if isinstance(obj, type) and name != "BaseEnv": + if hasattr(obj, "config_init") and callable(getattr(obj, "config_init")): + env_class = obj + break + + if not env_class: + return {} + + # Call config_init to get the actual config + env_config, server_configs = env_class.config_init() + config_class = type(env_config) + + # Extract fields from the Pydantic model + fields = {} + for field_name, field_info in config_class.model_fields.items(): + field_type = field_info.annotation + default = field_info.default + description = field_info.description or "" + + is_locked = field_name in LOCKED_FIELD_NAMES + + # Convert type to string + type_name = getattr(field_type, "__name__", str(field_type)) + if hasattr(field_type, "__origin__"): + type_name = str(field_type) + + fields[field_name] = { + "type": type_name, + "default": default if default is not None else None, + "description": description, + "locked": is_locked, + "current_value": LOCKED_FIELDS.get("env", {}).get(field_name, default) if is_locked else default, + } + + return fields + + except Exception as e: + print(f"Warning: Could not introspect environment config: {e}") + return {} + + +def _initialize_environments(): + """Initialize environment list on first use.""" + global _environments + if not _environments: + _environments = _scan_environments() + + +# ============================================================================ +# Subprocess Management +# ============================================================================ + +async def _spawn_training_run(run_state: RunState, config_path: Path): + """ + Spawn the three processes needed for training: + 1. 
run-api (Atropos API server) + 2. launch_training.py (Tinker trainer + inference server) + 3. environment.py serve (the Atropos environment) + """ + run_id = run_state.run_id + + # Log file paths + api_log = LOGS_DIR / f"api_{run_id}.log" + trainer_log = LOGS_DIR / f"trainer_{run_id}.log" + env_log = LOGS_DIR / f"env_{run_id}.log" + + try: + # Step 1: Start the Atropos API server (run-api) + print(f"[{run_id}] Starting Atropos API server (run-api)...") + + api_log_file = open(api_log, "w") + run_state.api_process = subprocess.Popen( + ["run-api"], + stdout=api_log_file, + stderr=subprocess.STDOUT, + cwd=str(TINKER_ATROPOS_ROOT), + ) + + # Wait for API to start + await asyncio.sleep(5) + + if run_state.api_process.poll() is not None: + run_state.status = "failed" + run_state.error_message = f"API server exited with code {run_state.api_process.returncode}. Check {api_log}" + return + + print(f"[{run_id}] Atropos API server started") + + # Step 2: Start the Tinker trainer + print(f"[{run_id}] Starting Tinker trainer: launch_training.py --config {config_path}") + + trainer_log_file = open(trainer_log, "w") + run_state.trainer_process = subprocess.Popen( + ["python", "launch_training.py", "--config", str(config_path)], + stdout=trainer_log_file, + stderr=subprocess.STDOUT, + cwd=str(TINKER_ATROPOS_ROOT), + env={**os.environ, "TINKER_API_KEY": os.getenv("TINKER_API_KEY", "")}, + ) + + # Wait for trainer to initialize (it starts FastAPI inference server on 8001) + print(f"[{run_id}] Waiting 30 seconds for trainer to initialize...") + await asyncio.sleep(30) + + if run_state.trainer_process.poll() is not None: + run_state.status = "failed" + run_state.error_message = f"Trainer exited with code {run_state.trainer_process.returncode}. 
Check {trainer_log}" + if run_state.api_process: + run_state.api_process.terminate() + return + + print(f"[{run_id}] Trainer started, inference server on port 8001") + + # Step 3: Start the environment + print(f"[{run_id}] Waiting 90 more seconds before starting environment...") + await asyncio.sleep(90) + + # Find the environment file + env_info = None + for env in _environments: + if env.name == run_state.environment: + env_info = env + break + + if not env_info: + run_state.status = "failed" + run_state.error_message = f"Environment '{run_state.environment}' not found" + return + + print(f"[{run_id}] Starting environment: {env_info.file_path} serve") + + env_log_file = open(env_log, "w") + run_state.env_process = subprocess.Popen( + ["python", str(env_info.file_path), "serve", "--config", str(config_path)], + stdout=env_log_file, + stderr=subprocess.STDOUT, + cwd=str(TINKER_ATROPOS_ROOT), + ) + + # Wait for environment to connect + await asyncio.sleep(10) + + if run_state.env_process.poll() is not None: + run_state.status = "failed" + run_state.error_message = f"Environment exited with code {run_state.env_process.returncode}. 
Check {env_log}" + if run_state.trainer_process: + run_state.trainer_process.terminate() + if run_state.api_process: + run_state.api_process.terminate() + return + + run_state.status = "running" + run_state.start_time = time.time() + print(f"[{run_id}] Training run started successfully!") + + # Start background monitoring + asyncio.create_task(_monitor_training_run(run_state)) + + except Exception as e: + run_state.status = "failed" + run_state.error_message = str(e) + _stop_training_run(run_state) + + +async def _monitor_training_run(run_state: RunState): + """Background task to monitor a training run.""" + while run_state.status == "running": + await asyncio.sleep(30) # Check every 30 seconds + + # Check if any process has died + if run_state.env_process and run_state.env_process.poll() is not None: + exit_code = run_state.env_process.returncode + if exit_code == 0: + run_state.status = "completed" + else: + run_state.status = "failed" + run_state.error_message = f"Environment process exited with code {exit_code}" + _stop_training_run(run_state) + break + + if run_state.trainer_process and run_state.trainer_process.poll() is not None: + exit_code = run_state.trainer_process.returncode + if exit_code == 0: + run_state.status = "completed" + else: + run_state.status = "failed" + run_state.error_message = f"Trainer process exited with code {exit_code}" + _stop_training_run(run_state) + break + + if run_state.api_process and run_state.api_process.poll() is not None: + run_state.status = "failed" + run_state.error_message = f"API server exited unexpectedly" + _stop_training_run(run_state) + break + + +def _stop_training_run(run_state: RunState): + """Stop all processes for a training run.""" + # Stop in reverse order: env -> trainer -> api + if run_state.env_process and run_state.env_process.poll() is None: + print(f"[{run_state.run_id}] Stopping environment process...") + run_state.env_process.terminate() + try: + run_state.env_process.wait(timeout=10) + except 
subprocess.TimeoutExpired: + run_state.env_process.kill() + + if run_state.trainer_process and run_state.trainer_process.poll() is None: + print(f"[{run_state.run_id}] Stopping trainer process...") + run_state.trainer_process.terminate() + try: + run_state.trainer_process.wait(timeout=10) + except subprocess.TimeoutExpired: + run_state.trainer_process.kill() + + if run_state.api_process and run_state.api_process.poll() is None: + print(f"[{run_state.run_id}] Stopping API server...") + run_state.api_process.terminate() + try: + run_state.api_process.wait(timeout=10) + except subprocess.TimeoutExpired: + run_state.api_process.kill() + + if run_state.status == "running": + run_state.status = "stopped" # ============================================================================ @@ -113,20 +472,23 @@ async def rl_list_environments() -> str: 2. Study how they load datasets, define verifiers, and structure rewards 3. Inspect HuggingFace datasets to understand data formats 4. Copy an existing environment as a template - 5. Test with rl_test_inference before running full training Returns: - JSON string with list of environments or error message + JSON string with list of environments """ - result = await _make_request("GET", "/environments") + _initialize_environments() - if "error" in result: - return json.dumps(result, indent=2) - - # Add helpful tips to the response response = { - "environments": result, - "count": len(result), + "environments": [ + { + "name": env.name, + "class_name": env.class_name, + "file_path": env.file_path, + "description": env.description, + } + for env in _environments + ], + "count": len(_environments), "tips": [ "Use rl_select_environment(name) to select an environment", "Read the file_path with file tools to understand how each environment works", @@ -141,23 +503,58 @@ async def rl_select_environment(name: str) -> str: """ Select an RL environment for training. - This loads the environment's default configuration into the config state. 
- After selecting, use rl_get_current_config() to see the configuration + This loads the environment's configuration fields into memory. + After selecting, use rl_get_current_config() to see all configurable options and rl_edit_config() to modify specific fields. Args: name: Name of the environment to select (from rl_list_environments) Returns: - JSON string with selection result, file path, and current config + JSON string with selection result, file path, and configurable field count - TIP: Read the returned file_path to understand how the environment works: - - How it loads data (load_dataset calls) - - How it verifies answers (score_answer method) - - What prompts it uses (system_prompt, get_next_item) + TIP: Read the returned file_path to understand how the environment works. """ - result = await _make_request("POST", f"/environments/{name}/select") - return json.dumps(result, indent=2) + global _current_env, _current_config, _env_config_cache + + _initialize_environments() + + env_info = None + for env in _environments: + if env.name == name: + env_info = env + break + + if not env_info: + return json.dumps({ + "error": f"Environment '{name}' not found", + "available": [e.name for e in _environments], + }, indent=2) + + _current_env = name + + # Dynamically discover config fields + config_fields = _get_env_config_fields(env_info.file_path) + _env_config_cache[name] = config_fields + + # Initialize current config with defaults for non-locked fields + _current_config = {} + for field_name, field_info in config_fields.items(): + if not field_info.get("locked", False): + _current_config[field_name] = field_info.get("default") + + configurable_count = sum(1 for f in config_fields.values() if not f.get("locked", False)) + locked_count = sum(1 for f in config_fields.values() if f.get("locked", False)) + + return json.dumps({ + "message": f"Selected environment: {name}", + "environment": name, + "file_path": env_info.file_path, + "configurable_fields": 
configurable_count, + "locked_fields": locked_count, + "config": _current_config, + "tip": f"Use rl_get_current_config() to see all {configurable_count} configurable fields.", + }, indent=2) # ============================================================================ @@ -175,18 +572,40 @@ async def rl_get_current_config() -> str: - configurable_fields: Can be changed with rl_edit_config() - locked_fields: Infrastructure settings that cannot be changed - Common configurable fields include: - - group_size: Rollouts per prompt - - batch_size: Training batch size - - wandb_name: WandB run name prefix - - system_prompt: Model instructions - - And any environment-specific options - Returns: JSON string with configurable and locked fields """ - result = await _make_request("GET", "/config") - return json.dumps(result, indent=2) + if not _current_env: + return json.dumps({ + "error": "No environment selected. Use rl_select_environment(name) first.", + }, indent=2) + + config_fields = _env_config_cache.get(_current_env, {}) + + configurable = [] + locked = [] + + for field_name, field_info in config_fields.items(): + field_data = { + "name": field_name, + "type": field_info.get("type", "unknown"), + "default": field_info.get("default"), + "description": field_info.get("description", ""), + "current_value": _current_config.get(field_name, field_info.get("default")), + } + + if field_info.get("locked", False): + field_data["locked_value"] = LOCKED_FIELDS.get("env", {}).get(field_name) + locked.append(field_data) + else: + configurable.append(field_data) + + return json.dumps({ + "environment": _current_env, + "configurable_fields": configurable, + "locked_fields": locked, + "tip": "Use rl_edit_config(field, value) to change any configurable field.", + }, indent=2) async def rl_edit_config(field: str, value: Any) -> str: @@ -205,8 +624,36 @@ async def rl_edit_config(field: str, value: Any) -> str: Returns: JSON string with updated config or error message """ - result = await 
_make_request("POST", "/config", {"field": field, "value": value}) - return json.dumps(result, indent=2) + global _current_config + + if not _current_env: + return json.dumps({ + "error": "No environment selected. Use rl_select_environment(name) first.", + }, indent=2) + + config_fields = _env_config_cache.get(_current_env, {}) + + if field not in config_fields: + return json.dumps({ + "error": f"Unknown field '{field}'", + "available_fields": list(config_fields.keys()), + }, indent=2) + + field_info = config_fields[field] + if field_info.get("locked", False): + return json.dumps({ + "error": f"Field '{field}' is locked and cannot be changed", + "locked_value": LOCKED_FIELDS.get("env", {}).get(field), + }, indent=2) + + _current_config[field] = value + + return json.dumps({ + "message": f"Updated {field} = {value}", + "field": field, + "value": value, + "config": _current_config, + }, indent=2) # ============================================================================ @@ -218,24 +665,106 @@ async def rl_start_training() -> str: Start a new RL training run with the current environment and config. Requires an environment to be selected first using rl_select_environment(). - Use rl_edit_config() to set group_size, batch_size, wandb_project before starting. + Use rl_edit_config() to adjust configuration before starting. - Most training parameters are fixed (lora_rank=32, learning_rate=4e-5, etc.) - and cannot be changed. + This spawns three processes: + 1. run-api (Atropos trajectory API) + 2. launch_training.py (Tinker trainer + inference server) + 3. environment.py serve (the selected environment) WARNING: Training runs take hours. Use rl_check_status() to monitor progress (recommended: check every 30 minutes at most). Returns: JSON string with run_id and initial status - - TIP: Before starting training: - 1. Test with rl_test_inference() to verify the environment works - 2. Configure group_size and batch_size appropriately - 3. 
Monitor WandB metrics for reward/mean and percent_correct """ - result = await _make_request("POST", "/runs", {}) - return json.dumps(result, indent=2) + global _active_runs + + if not _current_env: + return json.dumps({ + "error": "No environment selected. Use rl_select_environment(name) first.", + }, indent=2) + + # Check API keys + if not os.getenv("TINKER_API_KEY"): + return json.dumps({ + "error": "TINKER_API_KEY not set. Add it to ~/.hermes/.env", + }, indent=2) + + # Find environment file + env_info = None + for env in _environments: + if env.name == _current_env: + env_info = env + break + + if not env_info or not Path(env_info.file_path).exists(): + return json.dumps({ + "error": f"Environment file not found for '{_current_env}'", + }, indent=2) + + # Generate run ID + run_id = str(uuid.uuid4())[:8] + + # Create config YAML + CONFIGS_DIR.mkdir(exist_ok=True) + config_path = CONFIGS_DIR / f"run_{run_id}.yaml" + + # Start with locked config as base + import copy + run_config = copy.deepcopy(LOCKED_FIELDS) + + if "env" not in run_config: + run_config["env"] = {} + + # Apply configurable fields + for field_name, value in _current_config.items(): + if value is not None and value != "": + run_config["env"][field_name] = value + + # Set WandB settings + wandb_project = _current_config.get("wandb_project", "atropos-tinker") + if "tinker" not in run_config: + run_config["tinker"] = {} + run_config["tinker"]["wandb_project"] = wandb_project + run_config["tinker"]["wandb_run_name"] = f"{_current_env}-{run_id}" + + if "wandb_name" in _current_config and _current_config["wandb_name"]: + run_config["env"]["wandb_name"] = _current_config["wandb_name"] + + with open(config_path, "w") as f: + yaml.dump(run_config, f, default_flow_style=False) + + # Create run state + run_state = RunState( + run_id=run_id, + environment=_current_env, + config=_current_config.copy(), + status="starting", + wandb_project=wandb_project, + wandb_run_name=f"{_current_env}-{run_id}", + ) + + 
_active_runs[run_id] = run_state + + # Start training in background + asyncio.create_task(_spawn_training_run(run_state, config_path)) + + return json.dumps({ + "run_id": run_id, + "status": "starting", + "environment": _current_env, + "config": _current_config, + "wandb_project": wandb_project, + "wandb_run_name": f"{_current_env}-{run_id}", + "config_path": str(config_path), + "logs": { + "api": str(LOGS_DIR / f"api_{run_id}.log"), + "trainer": str(LOGS_DIR / f"trainer_{run_id}.log"), + "env": str(LOGS_DIR / f"env_{run_id}.log"), + }, + "message": "Training starting. Use rl_check_status(run_id) to monitor (recommended: every 30 minutes).", + }, indent=2) async def rl_check_status(run_id: str) -> str: @@ -245,19 +774,11 @@ async def rl_check_status(run_id: str) -> str: RATE LIMITED: For long-running training, this function enforces a minimum 30-minute interval between checks for the same run_id. - Fetches latest metrics from WandB if available: - - step: Current training step - - state: Run state (running, finished, crashed) - - reward_mean: Average reward across batches - - loss: Training loss - - percent_correct: Training accuracy - - eval_percent_correct: Evaluation accuracy - Args: run_id: The run ID returned by rl_start_training() Returns: - JSON string with run status and metrics, or rate limit message + JSON string with run status and metrics """ global _last_status_check @@ -275,7 +796,65 @@ async def rl_check_status(run_id: str) -> str: }, indent=2) _last_status_check[run_id] = now - result = await _make_request("GET", f"/runs/{run_id}") + + if run_id not in _active_runs: + return json.dumps({ + "error": f"Run '{run_id}' not found", + "active_runs": list(_active_runs.keys()), + }, indent=2) + + run_state = _active_runs[run_id] + + # Check process status + processes = { + "api": run_state.api_process.poll() if run_state.api_process else None, + "trainer": run_state.trainer_process.poll() if run_state.trainer_process else None, + "env": 
run_state.env_process.poll() if run_state.env_process else None, + } + + running_time = time.time() - run_state.start_time if run_state.start_time else 0 + + result = { + "run_id": run_id, + "status": run_state.status, + "environment": run_state.environment, + "running_time_minutes": running_time / 60, + "processes": { + name: "running" if code is None else f"exited ({code})" + for name, code in processes.items() + }, + "wandb_project": run_state.wandb_project, + "wandb_run_name": run_state.wandb_run_name, + "logs": { + "api": str(LOGS_DIR / f"api_{run_id}.log"), + "trainer": str(LOGS_DIR / f"trainer_{run_id}.log"), + "env": str(LOGS_DIR / f"env_{run_id}.log"), + }, + } + + if run_state.error_message: + result["error"] = run_state.error_message + + # Try to get WandB metrics if available + try: + import wandb + api = wandb.Api() + runs = api.runs( + f"{os.getenv('WANDB_ENTITY', 'nousresearch')}/{run_state.wandb_project}", + filters={"display_name": run_state.wandb_run_name} + ) + if runs: + wandb_run = runs[0] + result["wandb_url"] = wandb_run.url + result["metrics"] = { + "step": wandb_run.summary.get("_step", 0), + "reward_mean": wandb_run.summary.get("train/reward_mean"), + "percent_correct": wandb_run.summary.get("train/percent_correct"), + "eval_percent_correct": wandb_run.summary.get("eval/percent_correct"), + } + except Exception as e: + result["wandb_error"] = str(e) + return json.dumps(result, indent=2) @@ -283,84 +862,78 @@ async def rl_stop_training(run_id: str) -> str: """ Stop a running training job. 
- Use this if: - - Metrics look bad or training is stagnant - - You want to try different settings - - You need to free up resources - Args: run_id: The run ID to stop Returns: JSON string with stop confirmation """ - result = await _make_request("POST", f"/runs/{run_id}/stop") - return json.dumps(result, indent=2) + if run_id not in _active_runs: + return json.dumps({ + "error": f"Run '{run_id}' not found", + "active_runs": list(_active_runs.keys()), + }, indent=2) + + run_state = _active_runs[run_id] + + if run_state.status not in ("running", "starting"): + return json.dumps({ + "message": f"Run '{run_id}' is not running (status: {run_state.status})", + }, indent=2) + + _stop_training_run(run_state) + + return json.dumps({ + "message": f"Stopped training run '{run_id}'", + "run_id": run_id, + "status": run_state.status, + }, indent=2) async def rl_get_results(run_id: str) -> str: """ - Get final results and metrics for a completed training run. - - Returns: - - Final metrics (reward, loss, accuracy) - - WandB run URL for detailed analysis - - Path to trained weights (tinker:// URL) + Get final results and metrics for a training run. 
Args: run_id: The run ID to get results for Returns: - JSON string with final results and weights path + JSON string with final results """ - result = await _make_request("GET", f"/runs/{run_id}/metrics") + if run_id not in _active_runs: + return json.dumps({ + "error": f"Run '{run_id}' not found", + }, indent=2) + + run_state = _active_runs[run_id] + + result = { + "run_id": run_id, + "status": run_state.status, + "environment": run_state.environment, + "wandb_project": run_state.wandb_project, + "wandb_run_name": run_state.wandb_run_name, + } + + # Get WandB metrics + try: + import wandb + api = wandb.Api() + runs = api.runs( + f"{os.getenv('WANDB_ENTITY', 'nousresearch')}/{run_state.wandb_project}", + filters={"display_name": run_state.wandb_run_name} + ) + if runs: + wandb_run = runs[0] + result["wandb_url"] = wandb_run.url + result["final_metrics"] = dict(wandb_run.summary) + result["history"] = [dict(row) for row in wandb_run.history(samples=10)] + except Exception as e: + result["wandb_error"] = str(e) + return json.dumps(result, indent=2) -# ============================================================================ -# Inference Testing Tools -# ============================================================================ - -async def rl_test_inference( - prompts: List[str], - max_tokens: int = 256, - temperature: float = 1.0, -) -> str: - """ - Test inference + verifier on sample prompts WITHOUT full training. - - Use this to validate environments before committing to long training runs. - Tests: - - Data loading and formatting - - Model inference through Tinker - - Verifier/reward function logic - - NOTE: This still requires the RL API server to be running with - Tinker access for the Sample() method. 
- - Args: - prompts: List of test prompts to run through the environment - max_tokens: Maximum tokens to generate per prompt - temperature: Sampling temperature - - Returns: - JSON string with responses and verifier scores for each prompt - - TIP: Include prompts with known correct/incorrect answers to verify - the reward function is working correctly. - """ - result = await _make_request("POST", "/test/inference", { - "prompts": prompts, - "max_tokens": max_tokens, - "temperature": temperature, - }) - return json.dumps(result, indent=2) - - -# ============================================================================ -# Utility Tools -# ============================================================================ - async def rl_list_runs() -> str: """ List all training runs (active and completed). @@ -368,8 +941,252 @@ async def rl_list_runs() -> str: Returns: JSON string with list of runs and their status """ - result = await _make_request("GET", "/runs") - return json.dumps(result, indent=2) + runs = [] + for run_id, run_state in _active_runs.items(): + runs.append({ + "run_id": run_id, + "environment": run_state.environment, + "status": run_state.status, + "wandb_run_name": run_state.wandb_run_name, + }) + + return json.dumps({ + "runs": runs, + "count": len(runs), + }, indent=2) + + +# ============================================================================ +# Inference Testing (via Atropos `process` mode with OpenRouter) +# ============================================================================ + +# Test models at different scales for robustness testing +TEST_MODELS = [ + {"id": "qwen/qwen3-8b", "name": "Qwen3 8B", "scale": "small"}, + {"id": "zhipu-ai/glm-4-flash", "name": "GLM-4 Flash", "scale": "medium"}, + {"id": "minimax/minimax-m1", "name": "MiniMax M1", "scale": "large"}, +] + +# Default test parameters - quick but representative +DEFAULT_NUM_STEPS = 3 # Number of steps (items) to test +DEFAULT_GROUP_SIZE = 16 # Completions per item (like 
training) + + +async def rl_test_inference( + num_steps: int = DEFAULT_NUM_STEPS, + group_size: int = DEFAULT_GROUP_SIZE, + models: Optional[List[str]] = None, +) -> str: + """ + Quick inference test for any environment using Atropos's `process` mode. + + Runs a few steps of inference + scoring to validate: + - Environment loads correctly + - Prompt construction works + - Inference parsing is robust (tested with multiple model scales) + - Verifier/scoring logic works + + Default: 3 steps Ɨ 16 completions = 48 total rollouts per model. + Tests 3 models = 144 total rollouts. Quick sanity check. + + Test models (varying intelligence levels for robustness): + - qwen/qwen3-8b (small) + - zhipu-ai/glm-4-flash (medium) + - minimax/minimax-m1 (large) + + Args: + num_steps: Steps to run (default: 3, max recommended for testing) + group_size: Completions per step (default: 16, like training) + models: Optional model IDs to test. If None, uses all 3 test models. + + Returns: + JSON with results per model: steps_tested, accuracy, scores + """ + if not _current_env: + return json.dumps({ + "error": "No environment selected. Use rl_select_environment(name) first.", + }, indent=2) + + api_key = os.getenv("OPENROUTER_API_KEY") + if not api_key: + return json.dumps({ + "error": "OPENROUTER_API_KEY not set. 
Required for inference testing.", + }, indent=2) + + # Find environment info + env_info = None + for env in _environments: + if env.name == _current_env: + env_info = env + break + + if not env_info: + return json.dumps({ + "error": f"Environment '{_current_env}' not found", + }, indent=2) + + # Determine which models to test + if models: + test_models = [m for m in TEST_MODELS if m["id"] in models] + if not test_models: + test_models = [{"id": m, "name": m, "scale": "custom"} for m in models] + else: + test_models = TEST_MODELS + + # Calculate total rollouts for logging + total_rollouts_per_model = num_steps * group_size + total_rollouts = total_rollouts_per_model * len(test_models) + + results = { + "environment": _current_env, + "environment_file": env_info.file_path, + "test_config": { + "num_steps": num_steps, + "group_size": group_size, + "rollouts_per_model": total_rollouts_per_model, + "total_rollouts": total_rollouts, + }, + "models_tested": [], + } + + # Create output directory for test results + test_output_dir = LOGS_DIR / "inference_tests" + test_output_dir.mkdir(exist_ok=True) + + for model_info in test_models: + model_id = model_info["id"] + model_safe_name = model_id.replace("/", "_") + + print(f"\n{'='*60}") + print(f"Testing with {model_info['name']} ({model_id})") + print(f"{'='*60}") + + # Output file for this test run + output_file = test_output_dir / f"test_{_current_env}_{model_safe_name}.jsonl" + + # Build the process command using Atropos's built-in CLI + # This runs the environment's actual code with OpenRouter as the inference backend + cmd = [ + "python", env_info.file_path, "process", + "--env.total_steps", str(num_steps), + "--env.group_size", str(group_size), + "--env.use_wandb", "false", + "--env.data_path_to_save_groups", str(output_file), + "--openai.base_url", "https://openrouter.ai/api/v1", + "--openai.api_key", api_key, + "--openai.model_name", model_id, + ] + + print(f"Running: python {Path(env_info.file_path).name} process 
...") + print(f" {num_steps} steps Ɨ {group_size} completions = {total_rollouts_per_model} rollouts") + + model_results = { + "model": model_id, + "name": model_info["name"], + "scale": model_info["scale"], + "output_file": str(output_file), + "steps": [], + "steps_tested": 0, + "total_completions": 0, + "correct_completions": 0, + } + + try: + # Run the process command + process = await asyncio.create_subprocess_exec( + *cmd, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + cwd=str(TINKER_ATROPOS_ROOT), + ) + + stdout, stderr = await asyncio.wait_for( + process.communicate(), + timeout=600, # 10 minute timeout per model + ) + + if process.returncode != 0: + model_results["error"] = f"Process exited with code {process.returncode}" + model_results["stderr"] = stderr.decode()[-1000:] + print(f" Error: {model_results['error']}") + else: + print(f" Process completed successfully") + + # Parse the output JSONL file + if output_file.exists(): + # Read JSONL file (one JSON object per line = one step) + with open(output_file, "r") as f: + for line in f: + line = line.strip() + if not line: + continue + try: + item = json.loads(line) + scores = item.get("scores", []) + model_results["steps_tested"] += 1 + model_results["total_completions"] += len(scores) + correct = sum(1 for s in scores if s > 0) + model_results["correct_completions"] += correct + + model_results["steps"].append({ + "step": model_results["steps_tested"], + "completions": len(scores), + "correct": correct, + "scores": scores, + }) + except json.JSONDecodeError: + continue + + print(f" Completed {model_results['steps_tested']} steps") + else: + model_results["error"] = f"Output file not created: {output_file}" + + except asyncio.TimeoutError: + model_results["error"] = "Process timed out after 10 minutes" + print(f" Timeout!") + except Exception as e: + model_results["error"] = str(e) + print(f" Error: {e}") + + # Calculate stats + if model_results["total_completions"] > 0: + 
model_results["accuracy"] = round( + model_results["correct_completions"] / model_results["total_completions"], 3 + ) + else: + model_results["accuracy"] = 0 + + if model_results["steps_tested"] > 0: + steps_with_correct = sum(1 for s in model_results["steps"] if s.get("correct", 0) > 0) + model_results["steps_with_correct"] = steps_with_correct + model_results["step_success_rate"] = round( + steps_with_correct / model_results["steps_tested"], 3 + ) + else: + model_results["steps_with_correct"] = 0 + model_results["step_success_rate"] = 0 + + print(f" Results: {model_results['correct_completions']}/{model_results['total_completions']} correct") + print(f" Accuracy: {model_results['accuracy']:.1%}") + + results["models_tested"].append(model_results) + + # Overall summary + working_models = [m for m in results["models_tested"] if m.get("steps_tested", 0) > 0] + + results["summary"] = { + "steps_requested": num_steps, + "models_tested": len(test_models), + "models_succeeded": len(working_models), + "best_model": max(working_models, key=lambda x: x.get("accuracy", 0))["model"] if working_models else None, + "avg_accuracy": round( + sum(m.get("accuracy", 0) for m in working_models) / len(working_models), 3 + ) if working_models else 0, + "environment_working": len(working_models) > 0, + "output_directory": str(test_output_dir), + } + + return json.dumps(results, indent=2) # ============================================================================ @@ -378,27 +1195,16 @@ async def rl_list_runs() -> str: def check_rl_api_keys() -> bool: """ - Check if required API keys are available in environment variables. - - Required: - - TINKER_API_KEY: For Tinker training service - - WANDB_API_KEY: For metrics logging and fetching - - Returns: - bool: True if all required keys are set, False otherwise + Check if required API keys are available. 
""" tinker_key = os.getenv("TINKER_API_KEY") wandb_key = os.getenv("WANDB_API_KEY") - return bool(tinker_key) and bool(wandb_key) def get_missing_keys() -> List[str]: """ Get list of missing required API keys. - - Returns: - List of missing key names """ missing = [] if not os.getenv("TINKER_API_KEY"): @@ -406,18 +1212,3 @@ def get_missing_keys() -> List[str]: if not os.getenv("WANDB_API_KEY"): missing.append("WANDB_API_KEY") return missing - - -# ============================================================================ -# Debug/Status -# ============================================================================ - -async def rl_health_check() -> str: - """ - Check if the RL API server is running and accessible. - - Returns: - JSON string with server health status - """ - result = await _make_request("GET", "/health") - return json.dumps(result, indent=2) diff --git a/toolsets.py b/toolsets.py index e4644251..abd6192a 100644 --- a/toolsets.py +++ b/toolsets.py @@ -97,7 +97,7 @@ TOOLSETS = { "rl_get_current_config", "rl_edit_config", "rl_start_training", "rl_check_status", "rl_stop_training", "rl_get_results", - "rl_test_inference", "rl_list_runs" + "rl_list_runs", "rl_test_inference" ], "includes": [] }, From 3c0d0dba49f99da4b4e363545dfe1e2fac0417e6 Mon Sep 17 00:00:00 2001 From: teknium1 Date: Wed, 4 Feb 2026 13:57:59 -0800 Subject: [PATCH 4/6] Update RL tools and enhance configuration management - Modified `model_tools.py` to update default model IDs and add new RL function `rl_test_inference`. - Enhanced `README.md` with installation instructions for submodules and updated API key usage. - Improved `rl_cli.py` to load configuration from `~/.hermes/config.yaml` and set terminal working directory for RL tools. - Updated `run_agent.py` to handle empty string arguments as empty objects for better JSON validation. - Refined installation scripts to ensure submodules are cloned and installed correctly, enhancing setup experience. 
--- README.md | 23 ++++--- model_tools.py | 10 ++-- rl_cli.py | 91 +++++++++++++++++++++++++--- run_agent.py | 8 ++- scripts/install.ps1 | 42 ++++++++++++- scripts/install.sh | 34 +++++++++-- tools/rl_training_tool.py | 122 +++++++++++++++++++++++++++++++------- 7 files changed, 274 insertions(+), 56 deletions(-) diff --git a/README.md b/README.md index f49ae26a..a1673c91 100644 --- a/README.md +++ b/README.md @@ -15,7 +15,7 @@ irm https://raw.githubusercontent.com/NousResearch/hermes-agent/main/scripts/ins ``` The installer will: -- Clone to `~/.hermes-agent` +- Clone to `~/.hermes-agent` (with submodules: mini-swe-agent, tinker-atropos) - Create a virtual environment - Install all dependencies - Run the interactive setup wizard @@ -281,18 +281,10 @@ Train language models with reinforcement learning using the Tinker API and Atrop ```bash TINKER_API_KEY=your-tinker-key # Get from https://tinker-console.thinkingmachines.ai/keys WANDB_API_KEY=your-wandb-key # Get from https://wandb.ai/authorize +OPENROUTER_API_KEY=your-key # Optional: for rl_test_inference ``` -2. **Install tinker-atropos:** (in a separate directory) -```bash -cd ~/tinker-atropos -pip install -e . -``` - -3. **Start the RL API server:** -```bash -rl-server # Runs on port 8080 by default -``` +2. **That's it!** tinker-atropos is included as a submodule - no separate installation needed. #### Using RL Tools @@ -313,10 +305,12 @@ Agent: I'll set up an RL training run on the GSM8k environment... 
| `rl_select_environment` | Select an environment for training | | `rl_get_current_config` | View all configurable options | | `rl_edit_config` | Change a configuration value | +| `rl_test_inference` | Test environment with OpenRouter (pre-training validation) | | `rl_start_training` | Start a training run | | `rl_check_status` | Check training progress | | `rl_stop_training` | Stop a running training | | `rl_get_results` | Fetch WandB metrics | +| `rl_list_runs` | List active training runs | #### Dedicated RL CLI @@ -434,7 +428,7 @@ skills/ If you prefer not to use the installer: ```bash -# Clone the repository +# Clone the repository (with submodules) git clone --recurse-submodules https://github.com/NousResearch/hermes-agent.git cd hermes-agent @@ -445,6 +439,11 @@ cd hermes-agent python3 -m venv venv source venv/bin/activate pip install -e ".[all]" + +# Install submodules (required for terminal and RL tools) +pip install -e "./mini-swe-agent" # Terminal tool backend +pip install -e "./tinker-atropos" # RL training backend + hermes setup ``` diff --git a/model_tools.py b/model_tools.py index 847e56ef..e95a595c 100644 --- a/model_tools.py +++ b/model_tools.py @@ -665,7 +665,7 @@ def get_rl_tool_definitions() -> List[Dict[str, Any]]: "models": { "type": "array", "items": {"type": "string"}, - "description": "Optional list of OpenRouter model IDs. Default: qwen/qwen3-8b, zhipu-ai/glm-4-flash, minimax/minimax-m1" + "description": "Optional list of OpenRouter model IDs. 
Default: qwen/qwen3-8b, z-ai/glm-4.7-flash, minimax/minimax-m2.1" } }, "required": [] @@ -730,7 +730,7 @@ def get_all_tool_names() -> List[str]: "rl_get_current_config", "rl_edit_config", "rl_start_training", "rl_check_status", "rl_stop_training", "rl_get_results", - "rl_list_runs" + "rl_list_runs", "rl_test_inference" ]) return tool_names @@ -898,7 +898,7 @@ def get_tool_definitions( "rl_get_current_config", "rl_edit_config", "rl_start_training", "rl_check_status", "rl_stop_training", "rl_get_results", - "rl_list_runs" + "rl_list_runs", "rl_test_inference" ] } legacy_tools = legacy_map.get(toolset_name, []) @@ -950,7 +950,7 @@ def get_tool_definitions( "rl_get_current_config", "rl_edit_config", "rl_start_training", "rl_check_status", "rl_stop_training", "rl_get_results", - "rl_list_runs" + "rl_list_runs", "rl_test_inference" ] } legacy_tools = legacy_map.get(toolset_name, []) @@ -1407,7 +1407,7 @@ def handle_function_call( "rl_get_current_config", "rl_edit_config", "rl_start_training", "rl_check_status", "rl_stop_training", "rl_get_results", - "rl_list_runs" + "rl_list_runs", "rl_test_inference" ]: return handle_rl_function_call(function_name, function_args) diff --git a/rl_cli.py b/rl_cli.py index fe0eecfd..a45c365b 100644 --- a/rl_cli.py +++ b/rl_cli.py @@ -25,14 +25,34 @@ import sys from pathlib import Path import fire +import yaml # Load environment variables from .env file from dotenv import load_dotenv -env_path = Path(__file__).parent / '.env' -if env_path.exists(): - load_dotenv(dotenv_path=env_path) - print(f"āœ… Loaded environment variables from {env_path}") +# Load from ~/.hermes/.env first, then local .env +hermes_env_path = Path.home() / '.hermes' / '.env' +local_env_path = Path(__file__).parent / '.env' + +if hermes_env_path.exists(): + load_dotenv(dotenv_path=hermes_env_path) + print(f"āœ… Loaded environment variables from {hermes_env_path}") +elif local_env_path.exists(): + load_dotenv(dotenv_path=local_env_path) + print(f"āœ… Loaded environment 
variables from {local_env_path}") + +# Set terminal working directory to tinker-atropos submodule +# This ensures terminal commands run in the right context for RL work +tinker_atropos_dir = Path(__file__).parent / 'tinker-atropos' +if tinker_atropos_dir.exists(): + os.environ['TERMINAL_CWD'] = str(tinker_atropos_dir) + os.environ['HERMES_QUIET'] = '1' # Disable temp subdirectory creation + print(f"šŸ“‚ Terminal working directory: {tinker_atropos_dir}") +else: + # Fall back to hermes-agent directory if submodule not found + os.environ['TERMINAL_CWD'] = str(Path(__file__).parent) + os.environ['HERMES_QUIET'] = '1' + print(f"āš ļø tinker-atropos submodule not found, using: {Path(__file__).parent}") # Import agent and tools from run_agent import AIAgent @@ -40,6 +60,50 @@ from model_tools import get_tool_definitions, check_toolset_requirements from tools.rl_training_tool import check_rl_api_keys, get_missing_keys +# ============================================================================ +# Config Loading +# ============================================================================ + +DEFAULT_MODEL = "anthropic/claude-opus-4.5" +DEFAULT_BASE_URL = "https://openrouter.ai/api/v1" + + +def load_hermes_config() -> dict: + """ + Load configuration from ~/.hermes/config.yaml. + + Returns: + dict: Configuration with model, base_url, etc. 
+ """ + config_path = Path.home() / '.hermes' / 'config.yaml' + + config = { + "model": DEFAULT_MODEL, + "base_url": DEFAULT_BASE_URL, + } + + if config_path.exists(): + try: + with open(config_path, "r") as f: + file_config = yaml.safe_load(f) or {} + + # Get model from config + if "model" in file_config: + if isinstance(file_config["model"], str): + config["model"] = file_config["model"] + elif isinstance(file_config["model"], dict): + config["model"] = file_config["model"].get("default", DEFAULT_MODEL) + + # Get base_url if specified + if "base_url" in file_config: + config["base_url"] = file_config["base_url"] + + except Exception as e: + print(f"āš ļø Warning: Failed to load config.yaml: {e}") + + return config + + # ============================================================================ # RL-Specific Configuration # ============================================================================ @@ -108,7 +172,7 @@ When asked to train a model, follow this workflow: """ # Toolsets to enable for RL workflows -RL_TOOLSETS = ["base", "terminal", "web", "rl"] +RL_TOOLSETS = ["terminal", "web", "rl"] # ============================================================================ @@ -172,9 +236,9 @@ def list_environments_sync(): def main( task: str = None, - model: str = "anthropic/claude-sonnet-4-20250514", + model: str = None, api_key: str = None, - base_url: str = "https://openrouter.ai/api/v1", + base_url: str = None, max_iterations: int = RL_MAX_ITERATIONS, interactive: bool = False, list_environments: bool = False, @@ -187,9 +251,9 @@ def main( Args: task: The training task/goal (e.g., "Train a model on GSM8k for math") - model: Model to use for the agent (default: claude-sonnet-4) + model: Model to use for the agent (reads from ~/.hermes/config.yaml if not provided) api_key: OpenRouter API key (uses OPENROUTER_API_KEY env var if not provided) - base_url: API base URL (default: OpenRouter) + base_url: API base URL (reads from config or defaults to OpenRouter) 
max_iterations: Maximum agent iterations (default: 200 for long workflows) interactive: Run in interactive mode (multiple conversations) list_environments: Just list available RL environments and exit @@ -210,6 +274,15 @@ def main( # Check server status python rl_cli.py --check-server """ + # Load config from ~/.hermes/config.yaml + config = load_hermes_config() + + # Use config values if not explicitly provided + if model is None: + model = config["model"] + if base_url is None: + base_url = config["base_url"] + print("šŸŽÆ RL Training Agent") print("=" * 60) diff --git a/run_agent.py b/run_agent.py index 7b70289f..1aceb5b5 100644 --- a/run_agent.py +++ b/run_agent.py @@ -1764,10 +1764,16 @@ class AIAgent: self._invalid_tool_retries = 0 # Validate tool call arguments are valid JSON + # Handle empty strings as empty objects (common model quirk) invalid_json_args = [] for tc in assistant_message.tool_calls: + args = tc.function.arguments + # Treat empty/whitespace strings as empty object + if not args or not args.strip(): + tc.function.arguments = "{}" + continue try: - json.loads(tc.function.arguments) + json.loads(args) except json.JSONDecodeError as e: invalid_json_args.append((tc.function.name, str(e))) diff --git a/scripts/install.ps1 b/scripts/install.ps1 index caf80288..3666b21b 100644 --- a/scripts/install.ps1 +++ b/scripts/install.ps1 @@ -150,14 +150,15 @@ function Install-Repository { } } else { # Try SSH first (for private repo access), fall back to HTTPS + # Use --recurse-submodules to also clone mini-swe-agent and tinker-atropos Write-Info "Trying SSH clone..." - $sshResult = git clone --branch $Branch $RepoUrlSsh $InstallDir 2>&1 + $sshResult = git clone --branch $Branch --recurse-submodules $RepoUrlSsh $InstallDir 2>&1 if ($LASTEXITCODE -eq 0) { Write-Success "Cloned via SSH" } else { Write-Info "SSH failed, trying HTTPS..." 
- $httpsResult = git clone --branch $Branch $RepoUrlHttps $InstallDir 2>&1 + $httpsResult = git clone --branch $Branch --recurse-submodules $RepoUrlHttps $InstallDir 2>&1 if ($LASTEXITCODE -eq 0) { Write-Success "Cloned via HTTPS" @@ -171,6 +172,13 @@ function Install-Repository { } } + # Ensure submodules are initialized and updated (for existing installs or if --recurse failed) + Write-Info "Initializing submodules (mini-swe-agent, tinker-atropos)..." + Push-Location $InstallDir + git submodule update --init --recursive + Pop-Location + Write-Success "Submodules ready" + Write-Success "Repository ready" } @@ -208,15 +216,43 @@ function Install-Dependencies { & .\venv\Scripts\Activate.ps1 } + # Install main package try { pip install -e ".[all]" 2>&1 | Out-Null } catch { pip install -e "." | Out-Null } + Write-Success "Main package installed" + + # Install submodules + Write-Info "Installing mini-swe-agent (terminal tool backend)..." + if (Test-Path "mini-swe-agent\pyproject.toml") { + try { + pip install -e ".\mini-swe-agent" 2>&1 | Out-Null + Write-Success "mini-swe-agent installed" + } catch { + Write-Warning "mini-swe-agent install failed (terminal tools may not work)" + } + } else { + Write-Warning "mini-swe-agent not found (run: git submodule update --init)" + } + + Write-Info "Installing tinker-atropos (RL training backend)..." 
+ if (Test-Path "tinker-atropos\pyproject.toml") { + try { + pip install -e ".\tinker-atropos" 2>&1 | Out-Null + Write-Success "tinker-atropos installed" + } catch { + Write-Warning "tinker-atropos install failed (RL tools may not work)" + } + } else { + Write-Warning "tinker-atropos not found (run: git submodule update --init)" + } + Pop-Location - Write-Success "Dependencies installed" + Write-Success "All dependencies installed" } function Set-PathVariable { diff --git a/scripts/install.sh b/scripts/install.sh index 463a0d5b..4b8affaa 100755 --- a/scripts/install.sh +++ b/scripts/install.sh @@ -292,12 +292,13 @@ clone_repo() { fi else # Try SSH first (for private repo access), fall back to HTTPS + # Use --recurse-submodules to also clone mini-swe-agent and tinker-atropos log_info "Trying SSH clone..." - if git clone --branch "$BRANCH" "$REPO_URL_SSH" "$INSTALL_DIR" 2>/dev/null; then + if git clone --branch "$BRANCH" --recurse-submodules "$REPO_URL_SSH" "$INSTALL_DIR" 2>/dev/null; then log_success "Cloned via SSH" else log_info "SSH failed, trying HTTPS..." - if git clone --branch "$BRANCH" "$REPO_URL_HTTPS" "$INSTALL_DIR"; then + if git clone --branch "$BRANCH" --recurse-submodules "$REPO_URL_HTTPS" "$INSTALL_DIR"; then log_success "Cloned via HTTPS" else log_error "Failed to clone repository" @@ -310,6 +311,12 @@ clone_repo() { fi cd "$INSTALL_DIR" + + # Ensure submodules are initialized and updated (for existing installs or if --recurse failed) + log_info "Initializing submodules (mini-swe-agent, tinker-atropos)..." + git submodule update --init --recursive + log_success "Submodules ready" + log_success "Repository ready" } @@ -343,10 +350,29 @@ install_deps() { source venv/bin/activate fi - # Install the package in editable mode with all extras + # Install the main package in editable mode with all extras pip install -e ".[all]" > /dev/null 2>&1 || pip install -e "." 
> /dev/null - log_success "Dependencies installed" + log_success "Main package installed" + + # Install submodules + log_info "Installing mini-swe-agent (terminal tool backend)..." + if [ -d "mini-swe-agent" ] && [ -f "mini-swe-agent/pyproject.toml" ]; then + pip install -e "./mini-swe-agent" > /dev/null 2>&1 || log_warn "mini-swe-agent install failed (terminal tools may not work)" + log_success "mini-swe-agent installed" + else + log_warn "mini-swe-agent not found (run: git submodule update --init)" + fi + + log_info "Installing tinker-atropos (RL training backend)..." + if [ -d "tinker-atropos" ] && [ -f "tinker-atropos/pyproject.toml" ]; then + pip install -e "./tinker-atropos" > /dev/null 2>&1 || log_warn "tinker-atropos install failed (RL tools may not work)" + log_success "tinker-atropos installed" + else + log_warn "tinker-atropos not found (run: git submodule update --init)" + fi + + log_success "All dependencies installed" } setup_path() { diff --git a/tools/rl_training_tool.py b/tools/rl_training_tool.py index 3c257c4c..8c18bee6 100644 --- a/tools/rl_training_tool.py +++ b/tools/rl_training_tool.py @@ -37,6 +37,7 @@ import subprocess import sys import time import uuid +from datetime import datetime import yaml from dataclasses import dataclass, field from pathlib import Path @@ -84,6 +85,7 @@ LOCKED_FIELDS = { "weight": 1.0, "num_requests_for_eval": 256, "timeout": 3600, + "server_type": "sglang", # Tinker uses sglang for actual training } ], "tinker": { @@ -211,6 +213,9 @@ def _scan_environments() -> List[EnvironmentInfo]: def _get_env_config_fields(env_file_path: str) -> Dict[str, Dict[str, Any]]: """ Dynamically import an environment and extract its config fields. + + Uses config_init() to get the actual config class, with fallback to + directly importing BaseEnvConfig if config_init fails. 
""" try: # Load the environment module @@ -230,15 +235,38 @@ def _get_env_config_fields(env_file_path: str) -> Dict[str, Dict[str, Any]]: if not env_class: return {} - # Call config_init to get the actual config - env_config, server_configs = env_class.config_init() - config_class = type(env_config) + # Try calling config_init to get the actual config class + config_class = None + try: + env_config, server_configs = env_class.config_init() + config_class = type(env_config) + except Exception as config_error: + # Fallback: try to import BaseEnvConfig directly from atroposlib + print(f"Note: config_init failed ({config_error}), using BaseEnvConfig defaults") + try: + from atroposlib.envs.base import BaseEnvConfig + config_class = BaseEnvConfig + except ImportError: + return {} + + if not config_class: + return {} + + # Helper to make values JSON-serializable (handle enums, etc.) + def make_serializable(val): + if val is None: + return None + if hasattr(val, 'value'): # Enum + return val.value + if hasattr(val, 'name') and hasattr(val, '__class__') and 'Enum' in str(type(val)): + return val.name + return val # Extract fields from the Pydantic model fields = {} for field_name, field_info in config_class.model_fields.items(): field_type = field_info.annotation - default = field_info.default + default = make_serializable(field_info.default) description = field_info.description or "" is_locked = field_name in LOCKED_FIELD_NAMES @@ -248,12 +276,15 @@ def _get_env_config_fields(env_file_path: str) -> Dict[str, Dict[str, Any]]: if hasattr(field_type, "__origin__"): type_name = str(field_type) + locked_value = LOCKED_FIELDS.get("env", {}).get(field_name, default) + current_value = make_serializable(locked_value) if is_locked else default + fields[field_name] = { "type": type_name, - "default": default if default is not None else None, + "default": default, "description": description, "locked": is_locked, - "current_value": LOCKED_FIELDS.get("env", {}).get(field_name, default) 
if is_locked else default, + "current_value": current_value, } return fields @@ -315,7 +346,7 @@ async def _spawn_training_run(run_state: RunState, config_path: Path): trainer_log_file = open(trainer_log, "w") run_state.trainer_process = subprocess.Popen( - ["python", "launch_training.py", "--config", str(config_path)], + [sys.executable, "launch_training.py", "--config", str(config_path)], stdout=trainer_log_file, stderr=subprocess.STDOUT, cwd=str(TINKER_ATROPOS_ROOT), @@ -355,7 +386,7 @@ async def _spawn_training_run(run_state: RunState, config_path: Path): env_log_file = open(env_log, "w") run_state.env_process = subprocess.Popen( - ["python", str(env_info.file_path), "serve", "--config", str(config_path)], + [sys.executable, str(env_info.file_path), "serve", "--config", str(config_path)], stdout=env_log_file, stderr=subprocess.STDOUT, cwd=str(TINKER_ATROPOS_ROOT), @@ -543,17 +574,14 @@ async def rl_select_environment(name: str) -> str: if not field_info.get("locked", False): _current_config[field_name] = field_info.get("default") - configurable_count = sum(1 for f in config_fields.values() if not f.get("locked", False)) - locked_count = sum(1 for f in config_fields.values() if f.get("locked", False)) + # Auto-set wandb_name to "{env_name}-DATETIME" to avoid overlaps + timestamp = datetime.now().strftime("%Y%m%d-%H%M%S") + _current_config["wandb_name"] = f"{name}-{timestamp}" return json.dumps({ "message": f"Selected environment: {name}", "environment": name, "file_path": env_info.file_path, - "configurable_fields": configurable_count, - "locked_fields": locked_count, - "config": _current_config, - "tip": f"Use rl_get_current_config() to see all {configurable_count} configurable fields.", }, indent=2) @@ -961,10 +989,11 @@ async def rl_list_runs() -> str: # ============================================================================ # Test models at different scales for robustness testing +# These are cheap, capable models on OpenRouter for testing 
parsing/scoring TEST_MODELS = [ {"id": "qwen/qwen3-8b", "name": "Qwen3 8B", "scale": "small"}, - {"id": "zhipu-ai/glm-4-flash", "name": "GLM-4 Flash", "scale": "medium"}, - {"id": "minimax/minimax-m1", "name": "MiniMax M1", "scale": "large"}, + {"id": "z-ai/glm-4.7-flash", "name": "GLM-4.7 Flash", "scale": "medium"}, + {"id": "minimax/minimax-m2.1", "name": "MiniMax M2.1", "scale": "large"}, ] # Default test parameters - quick but representative @@ -1066,18 +1095,35 @@ async def rl_test_inference( # Build the process command using Atropos's built-in CLI # This runs the environment's actual code with OpenRouter as the inference backend + # We pass our locked settings + test-specific overrides via CLI args cmd = [ - "python", env_info.file_path, "process", + sys.executable, env_info.file_path, "process", + # Test-specific overrides "--env.total_steps", str(num_steps), "--env.group_size", str(group_size), - "--env.use_wandb", "false", + "--env.use_wandb", "false", # No wandb for quick tests "--env.data_path_to_save_groups", str(output_file), + # Use locked settings from our config + "--env.tokenizer_name", LOCKED_FIELDS["env"]["tokenizer_name"], + "--env.max_token_length", str(LOCKED_FIELDS["env"]["max_token_length"]), + "--env.max_num_workers", str(LOCKED_FIELDS["env"]["max_num_workers"]), + "--env.max_batches_offpolicy", str(LOCKED_FIELDS["env"]["max_batches_offpolicy"]), + # OpenRouter config for inference testing + # IMPORTANT: Use server_type=openai for OpenRouter (not sglang) + # sglang is only for actual training with Tinker's inference server "--openai.base_url", "https://openrouter.ai/api/v1", "--openai.api_key", api_key, "--openai.model_name", model_id, + "--openai.server_type", "openai", # OpenRouter is OpenAI-compatible + "--openai.health_check", "false", # OpenRouter doesn't have health endpoint ] - print(f"Running: python {Path(env_info.file_path).name} process ...") + # Debug: Print the full command + cmd_str = " ".join(str(c) for c in cmd) + # Hide API 
key in printed output + cmd_display = cmd_str.replace(api_key, "***API_KEY***") + print(f"Command: {cmd_display}") + print(f"Working dir: {TINKER_ATROPOS_ROOT}") print(f" {num_steps} steps Ɨ {group_size} completions = {total_rollouts_per_model} rollouts") model_results = { @@ -1105,12 +1151,44 @@ async def rl_test_inference( timeout=600, # 10 minute timeout per model ) + # Decode output + stdout_text = stdout.decode() if stdout else "" + stderr_text = stderr.decode() if stderr else "" + + # Write logs to files for inspection outside CLI + log_file = test_output_dir / f"test_{_current_env}_{model_safe_name}.log" + with open(log_file, "w") as f: + f.write(f"Command: {cmd_display}\n") + f.write(f"Working dir: {TINKER_ATROPOS_ROOT}\n") + f.write(f"Return code: {process.returncode}\n") + f.write(f"\n{'='*60}\n") + f.write(f"STDOUT:\n{'='*60}\n") + f.write(stdout_text or "(empty)\n") + f.write(f"\n{'='*60}\n") + f.write(f"STDERR:\n{'='*60}\n") + f.write(stderr_text or "(empty)\n") + + print(f" Log file: {log_file}") + + # Print to console for immediate debugging + if stdout_text.strip(): + print(f"\n--- STDOUT ---") + print(stdout_text[-2000:]) # Last 2000 chars + + if stderr_text.strip(): + print(f"\n--- STDERR ---") + print(stderr_text[-2000:]) # Last 2000 chars + if process.returncode != 0: model_results["error"] = f"Process exited with code {process.returncode}" - model_results["stderr"] = stderr.decode()[-1000:] - print(f" Error: {model_results['error']}") + model_results["stderr"] = stderr_text[-1000:] + model_results["stdout"] = stdout_text[-1000:] + model_results["log_file"] = str(log_file) + print(f"\n āŒ Error: {model_results['error']}") else: - print(f" Process completed successfully") + print(f"\n āœ… Process completed successfully") + print(f" Output file: {output_file}") + print(f" File exists: {output_file.exists()}") # Parse the output JSONL file if output_file.exists(): From 5c3105b4376c7422b7c0c0f76e487f14a72a3e38 Mon Sep 17 00:00:00 2001 From: 
teknium1 Date: Wed, 4 Feb 2026 21:07:07 -0800 Subject: [PATCH 5/6] Enhance RL test inference with WandB integration and real-time output streaming - Added unique run ID generation for WandB tracking during test inference. - Enabled WandB usage for test tracking and updated command-line arguments accordingly. - Implemented real-time output streaming for process execution, improving log visibility and debugging. - Enhanced error handling to display last few lines of stderr for better troubleshooting. --- tools/rl_training_tool.py | 67 ++++++++++++++++++++++++++++----------- 1 file changed, 48 insertions(+), 19 deletions(-) diff --git a/tools/rl_training_tool.py b/tools/rl_training_tool.py index 8c18bee6..770c542c 100644 --- a/tools/rl_training_tool.py +++ b/tools/rl_training_tool.py @@ -1093,6 +1093,10 @@ async def rl_test_inference( # Output file for this test run output_file = test_output_dir / f"test_{_current_env}_{model_safe_name}.jsonl" + # Generate unique run ID for wandb + test_run_id = str(uuid.uuid4())[:8] + wandb_run_name = f"test_inference_RSIAgent_{_current_env}_{test_run_id}" + # Build the process command using Atropos's built-in CLI # This runs the environment's actual code with OpenRouter as the inference backend # We pass our locked settings + test-specific overrides via CLI args @@ -1101,7 +1105,8 @@ async def rl_test_inference( # Test-specific overrides "--env.total_steps", str(num_steps), "--env.group_size", str(group_size), - "--env.use_wandb", "false", # No wandb for quick tests + "--env.use_wandb", "true", # Enable wandb for test tracking + "--env.wandb_name", wandb_run_name, "--env.data_path_to_save_groups", str(output_file), # Use locked settings from our config "--env.tokenizer_name", LOCKED_FIELDS["env"]["tokenizer_name"], @@ -1124,12 +1129,14 @@ async def rl_test_inference( cmd_display = cmd_str.replace(api_key, "***API_KEY***") print(f"Command: {cmd_display}") print(f"Working dir: {TINKER_ATROPOS_ROOT}") + print(f"WandB run: 
{wandb_run_name}") print(f" {num_steps} steps Ɨ {group_size} completions = {total_rollouts_per_model} rollouts") model_results = { "model": model_id, "name": model_info["name"], "scale": model_info["scale"], + "wandb_run": wandb_run_name, "output_file": str(output_file), "steps": [], "steps_tested": 0, @@ -1138,7 +1145,7 @@ async def rl_test_inference( } try: - # Run the process command + # Run the process command with real-time output streaming process = await asyncio.create_subprocess_exec( *cmd, stdout=asyncio.subprocess.PIPE, @@ -1146,17 +1153,43 @@ async def rl_test_inference( cwd=str(TINKER_ATROPOS_ROOT), ) - stdout, stderr = await asyncio.wait_for( - process.communicate(), - timeout=600, # 10 minute timeout per model - ) + # Stream output in real-time while collecting for logs + stdout_lines = [] + stderr_lines = [] + log_file = test_output_dir / f"test_{_current_env}_{model_safe_name}.log" - # Decode output - stdout_text = stdout.decode() if stdout else "" - stderr_text = stderr.decode() if stderr else "" + async def read_stream(stream, lines_list, prefix=""): + """Read stream line by line and print in real-time.""" + while True: + line = await stream.readline() + if not line: + break + decoded = line.decode().rstrip() + lines_list.append(decoded) + # Print progress-related lines in real-time + if any(kw in decoded.lower() for kw in ['processing', 'group', 'step', 'progress', '%', 'completed']): + print(f" {prefix}{decoded}") + + # Read both streams concurrently with timeout + try: + await asyncio.wait_for( + asyncio.gather( + read_stream(process.stdout, stdout_lines, "šŸ“Š "), + read_stream(process.stderr, stderr_lines, "āš ļø "), + ), + timeout=600, # 10 minute timeout per model + ) + except asyncio.TimeoutError: + process.kill() + raise + + await process.wait() + + # Combine output for logging + stdout_text = "\n".join(stdout_lines) + stderr_text = "\n".join(stderr_lines) # Write logs to files for inspection outside CLI - log_file = test_output_dir / 
f"test_{_current_env}_{model_safe_name}.log" with open(log_file, "w") as f: f.write(f"Command: {cmd_display}\n") f.write(f"Working dir: {TINKER_ATROPOS_ROOT}\n") @@ -1170,21 +1203,17 @@ async def rl_test_inference( print(f" Log file: {log_file}") - # Print to console for immediate debugging - if stdout_text.strip(): - print(f"\n--- STDOUT ---") - print(stdout_text[-2000:]) # Last 2000 chars - - if stderr_text.strip(): - print(f"\n--- STDERR ---") - print(stderr_text[-2000:]) # Last 2000 chars - if process.returncode != 0: model_results["error"] = f"Process exited with code {process.returncode}" model_results["stderr"] = stderr_text[-1000:] model_results["stdout"] = stdout_text[-1000:] model_results["log_file"] = str(log_file) print(f"\n āŒ Error: {model_results['error']}") + # Print last few lines of stderr for debugging + if stderr_lines: + print(f" Last errors:") + for line in stderr_lines[-5:]: + print(f" {line}") else: print(f"\n āœ… Process completed successfully") print(f" Output file: {output_file}") From 533c064269417d4c213aa8393e3a6098a78fb5d1 Mon Sep 17 00:00:00 2001 From: teknium1 Date: Thu, 5 Feb 2026 03:49:46 -0800 Subject: [PATCH 6/6] Add file manipulation tools and enhance setup scripts - Introduced file manipulation capabilities in `model_tools.py`, including functions for reading, writing, patching, and searching files. - Added a new `file` toolset in `toolsets.py` and updated distributions to include file tools. - Enhanced `setup-hermes.sh` and `install.sh` scripts to check for and optionally install `ripgrep` for faster file searching. - Implemented a new `file_operations.py` module to encapsulate file operations using shell commands. - Updated `doctor.py` and `install.ps1` to check for `ripgrep` and provide installation guidance if not found. - Added fuzzy matching and patch parsing capabilities to improve file manipulation accuracy and flexibility. 
--- hermes_cli/doctor.py | 7 + model_tools.py | 266 ++++++++++- scripts/install.ps1 | 89 +++- scripts/install.sh | 124 ++++++ setup-hermes.sh | 47 ++ tools/__init__.py | 24 + tools/file_operations.py | 937 +++++++++++++++++++++++++++++++++++++++ tools/file_tools.py | 113 +++++ tools/fuzzy_match.py | 478 ++++++++++++++++++++ tools/patch_parser.py | 439 ++++++++++++++++++ toolset_distributions.py | 24 +- toolsets.py | 14 +- 12 files changed, 2549 insertions(+), 13 deletions(-) create mode 100644 tools/file_operations.py create mode 100644 tools/file_tools.py create mode 100644 tools/fuzzy_match.py create mode 100644 tools/patch_parser.py diff --git a/hermes_cli/doctor.py b/hermes_cli/doctor.py index 82b7e541..5e0ee39f 100644 --- a/hermes_cli/doctor.py +++ b/hermes_cli/doctor.py @@ -167,6 +167,13 @@ def run_doctor(args): else: check_warn("git not found", "(optional)") + # ripgrep (optional, for faster file search) + if shutil.which("rg"): + check_ok("ripgrep (rg)", "(faster file search)") + else: + check_warn("ripgrep (rg) not found", "(file search uses grep fallback)") + check_info("Install for faster search: sudo apt install ripgrep") + # Docker (optional) terminal_env = os.getenv("TERMINAL_ENV", "local") if terminal_env == "docker": diff --git a/model_tools.py b/model_tools.py index e95a595c..203a6669 100644 --- a/model_tools.py +++ b/model_tools.py @@ -33,6 +33,9 @@ from typing import Dict, Any, List, Optional, Tuple from tools.web_tools import web_search_tool, web_extract_tool, web_crawl_tool, check_firecrawl_api_key from tools.terminal_tool import terminal_tool, check_terminal_requirements, TERMINAL_TOOL_DESCRIPTION, cleanup_vm +# File manipulation tools (read, write, patch, search) +from tools.file_tools import read_file_tool, write_file_tool, patch_tool, search_tool +from tools import check_file_requirements # Hecate/MorphCloud terminal tool (cloud VMs) - available as alternative backend from tools.terminal_hecate import terminal_hecate_tool, 
check_hecate_requirements, TERMINAL_HECATE_DESCRIPTION from tools.vision_tools import vision_analyze_tool, check_vision_requirements @@ -155,6 +158,13 @@ TOOLSET_REQUIREMENTS = { "rl_list_runs", "rl_test_inference", ], }, + "file": { + "name": "File Operations (read, write, patch, search)", + "env_vars": [], # Uses terminal backend, no additional requirements + "check_fn": check_file_requirements, + "setup_url": None, + "tools": ["read_file", "write_file", "patch", "search"], + }, } @@ -675,6 +685,163 @@ def get_rl_tool_definitions() -> List[Dict[str, Any]]: ] +def get_file_tool_definitions() -> List[Dict[str, Any]]: + """ + Get tool definitions for file manipulation tools in OpenAI's expected format. + + File tools operate via the terminal backend and support any environment + (local, docker, singularity, ssh, modal). + + Returns: + List[Dict]: List of file tool definitions compatible with OpenAI API + """ + return [ + { + "type": "function", + "function": { + "name": "read_file", + "description": "Read a file with pagination support. Returns content with line numbers in 'LINE_NUM|CONTENT' format. For binary files (images), returns base64-encoded data. If file not found, suggests similar filenames.", + "parameters": { + "type": "object", + "properties": { + "path": { + "type": "string", + "description": "Path to the file to read (absolute or relative)" + }, + "offset": { + "type": "integer", + "description": "Line number to start reading from (1-indexed, default: 1)", + "default": 1, + "minimum": 1 + }, + "limit": { + "type": "integer", + "description": "Maximum number of lines to read (default: 500, max: 2000)", + "default": 500, + "maximum": 2000 + } + }, + "required": ["path"] + } + } + }, + { + "type": "function", + "function": { + "name": "write_file", + "description": "Write content to a file. Creates parent directories automatically. 
Returns bytes written and lint check results for supported languages.", + "parameters": { + "type": "object", + "properties": { + "path": { + "type": "string", + "description": "Path to the file to write (will be created if doesn't exist)" + }, + "content": { + "type": "string", + "description": "Content to write to the file" + } + }, + "required": ["path", "content"] + } + } + }, + { + "type": "function", + "function": { + "name": "patch", + "description": "Modify files using either simple string replacement or V4A patch format. Mode 'replace' does find-and-replace with fuzzy matching. Mode 'patch' applies multi-file changes using V4A format (*** Begin/End Patch). Auto-runs syntax checks on modified files.", + "parameters": { + "type": "object", + "properties": { + "mode": { + "type": "string", + "enum": ["replace", "patch"], + "description": "Edit mode: 'replace' for string replacement, 'patch' for V4A patch format", + "default": "replace" + }, + "path": { + "type": "string", + "description": "File path (required for 'replace' mode)" + }, + "old_string": { + "type": "string", + "description": "Text to find and replace (required for 'replace' mode). Must be unique in file unless replace_all=true" + }, + "new_string": { + "type": "string", + "description": "Replacement text (required for 'replace' mode)" + }, + "replace_all": { + "type": "boolean", + "description": "Replace all occurrences instead of requiring unique match (default: false)", + "default": False + }, + "patch": { + "type": "string", + "description": "V4A format patch content (required for 'patch' mode). Format: *** Begin Patch / *** Update File: path / @@ context @@ / -removed / +added / *** End Patch" + } + }, + "required": ["mode"] + } + } + }, + { + "type": "function", + "function": { + "name": "search", + "description": "Search for content in files or search for files by name. 
Use target='content' to search inside files (like grep), or target='files' to find files by name pattern (like glob/find). Results sorted by modification time (newest first).", + "parameters": { + "type": "object", + "properties": { + "pattern": { + "type": "string", + "description": "For target='content': regex pattern to search for. For target='files': glob pattern (e.g., '*.py', '*config*')" + }, + "target": { + "type": "string", + "enum": ["content", "files"], + "description": "Search mode: 'content' searches inside files, 'files' searches for files by name", + "default": "content" + }, + "path": { + "type": "string", + "description": "Directory or file to search in (default: current directory)", + "default": "." + }, + "file_glob": { + "type": "string", + "description": "Filter files by pattern when target='content' (e.g., '*.py' to only search Python files)" + }, + "limit": { + "type": "integer", + "description": "Maximum number of results (default: 50)", + "default": 50 + }, + "offset": { + "type": "integer", + "description": "Skip first N results for pagination (default: 0)", + "default": 0 + }, + "output_mode": { + "type": "string", + "enum": ["content", "files_only", "count"], + "description": "For target='content': 'content' shows matches, 'files_only' shows file paths, 'count' shows match counts per file", + "default": "content" + }, + "context": { + "type": "integer", + "description": "Lines of context around matches (only for target='content', output_mode='content')", + "default": 0 + } + }, + "required": ["pattern"] + } + } + } + ] + + def get_all_tool_names() -> List[str]: """ Get the names of all available tools across all toolsets. 
@@ -733,6 +900,12 @@ def get_all_tool_names() -> List[str]: "rl_list_runs", "rl_test_inference" ]) + # File manipulation tools (use terminal backend) + if check_file_requirements(): + tool_names.extend([ + "read_file", "write_file", "patch", "search" + ]) + return tool_names @@ -782,6 +955,11 @@ def get_toolset_for_tool(tool_name: str) -> str: "rl_stop_training": "rl_tools", "rl_get_results": "rl_tools", "rl_list_runs": "rl_tools", + # File manipulation tools + "read_file": "file_tools", + "write_file": "file_tools", + "patch": "file_tools", + "search": "file_tools", } return toolset_mapping.get(tool_name, "unknown") @@ -864,6 +1042,11 @@ def get_tool_definitions( for tool in get_rl_tool_definitions(): all_available_tools_map[tool["function"]["name"]] = tool + # File manipulation tools (use terminal backend) + if check_file_requirements(): + for tool in get_file_tool_definitions(): + all_available_tools_map[tool["function"]["name"]] = tool + # Determine which tools to include based on toolsets tools_to_include = set() @@ -899,7 +1082,8 @@ def get_tool_definitions( "rl_start_training", "rl_check_status", "rl_stop_training", "rl_get_results", "rl_list_runs", "rl_test_inference" - ] + ], + "file_tools": ["read_file", "write_file", "patch", "search"] } legacy_tools = legacy_map.get(toolset_name, []) tools_to_include.update(legacy_tools) @@ -951,7 +1135,8 @@ def get_tool_definitions( "rl_start_training", "rl_check_status", "rl_stop_training", "rl_get_results", "rl_list_runs", "rl_test_inference" - ] + ], + "file_tools": ["read_file", "write_file", "patch", "search"] } legacy_tools = legacy_map.get(toolset_name, []) tools_to_include.difference_update(legacy_tools) @@ -1338,6 +1523,70 @@ def handle_rl_function_call( return json.dumps({"error": f"Unknown RL function: {function_name}"}, ensure_ascii=False) +def handle_file_function_call( + function_name: str, + function_args: Dict[str, Any], + task_id: Optional[str] = None +) -> str: + """ + Handle function calls for file 
manipulation tools. + + These tools use the terminal backend for all operations, supporting + local, docker, singularity, ssh, and modal environments. + + Args: + function_name (str): Name of the file function to call + function_args (Dict): Arguments for the function + task_id (str): Task identifier for environment isolation + + Returns: + str: Function result as JSON string + """ + # Determine task_id to use + tid = task_id or "default" + + if function_name == "read_file": + return read_file_tool( + path=function_args.get("path", ""), + offset=function_args.get("offset", 1), + limit=function_args.get("limit", 500), + task_id=tid + ) + + elif function_name == "write_file": + return write_file_tool( + path=function_args.get("path", ""), + content=function_args.get("content", ""), + task_id=tid + ) + + elif function_name == "patch": + return patch_tool( + mode=function_args.get("mode", "replace"), + path=function_args.get("path"), + old_string=function_args.get("old_string"), + new_string=function_args.get("new_string"), + replace_all=function_args.get("replace_all", False), + patch=function_args.get("patch"), + task_id=tid + ) + + elif function_name == "search": + return search_tool( + pattern=function_args.get("pattern", ""), + target=function_args.get("target", "content"), + path=function_args.get("path", "."), + file_glob=function_args.get("file_glob"), + limit=function_args.get("limit", 50), + offset=function_args.get("offset", 0), + output_mode=function_args.get("output_mode", "content"), + context=function_args.get("context", 0), + task_id=tid + ) + + return json.dumps({"error": f"Unknown file function: {function_name}"}, ensure_ascii=False) + + def handle_function_call( function_name: str, function_args: Dict[str, Any], @@ -1411,6 +1660,10 @@ def handle_function_call( ]: return handle_rl_function_call(function_name, function_args) + # Route file manipulation tools + elif function_name in ["read_file", "write_file", "patch", "search"]: + return 
handle_file_function_call(function_name, function_args, task_id) + else: error_msg = f"Unknown function: {function_name}" print(f"āŒ {error_msg}") @@ -1482,6 +1735,12 @@ def get_available_toolsets() -> Dict[str, Dict[str, Any]]: "tools": ["schedule_cronjob", "list_cronjobs", "remove_cronjob"], "description": "Schedule and manage automated tasks (cronjobs) - only available in interactive CLI mode", "requirements": ["HERMES_INTERACTIVE=1 (set automatically by cli.py)"] + }, + "file_tools": { + "available": check_file_requirements(), + "tools": ["read_file", "write_file", "patch", "search"], + "description": "File manipulation tools: read/write files, search content/files, patch with fuzzy matching", + "requirements": ["Terminal backend available (local/docker/ssh/singularity/modal)"] } } @@ -1502,7 +1761,8 @@ def check_toolset_requirements() -> Dict[str, bool]: "image_tools": check_image_generation_requirements(), "skills_tools": check_skills_requirements(), "browser_tools": check_browser_requirements(), - "cronjob_tools": check_cronjob_requirements() + "cronjob_tools": check_cronjob_requirements(), + "file_tools": check_file_requirements() } if __name__ == "__main__": diff --git a/scripts/install.ps1 b/scripts/install.ps1 index 3666b21b..8170abba 100644 --- a/scripts/install.ps1 +++ b/scripts/install.ps1 @@ -128,6 +128,78 @@ function Test-Node { return $true # Don't fail - Node is optional } +function Test-Ripgrep { + Write-Info "Checking ripgrep (optional, for faster file search)..." 
+ + if (Get-Command rg -ErrorAction SilentlyContinue) { + $version = rg --version | Select-Object -First 1 + Write-Success "$version found" + $script:HasRipgrep = $true + return $true + } + + Write-Warning "ripgrep not found (file search will use findstr fallback)" + + # Check what package managers are available + $hasWinget = Get-Command winget -ErrorAction SilentlyContinue + $hasChoco = Get-Command choco -ErrorAction SilentlyContinue + $hasScoop = Get-Command scoop -ErrorAction SilentlyContinue + + # Offer to install + Write-Host "" + $response = Read-Host "Would you like to install ripgrep? (faster search, recommended) [Y/n]" + + if ($response -eq "" -or $response -match "^[Yy]") { + Write-Info "Installing ripgrep..." + + if ($hasWinget) { + try { + winget install BurntSushi.ripgrep.MSVC --silent 2>&1 | Out-Null + if ($LASTEXITCODE -eq 0) { + Write-Success "ripgrep installed via winget" + $script:HasRipgrep = $true + return $true + } + } catch { } + } + + if ($hasChoco) { + try { + choco install ripgrep -y 2>&1 | Out-Null + if ($LASTEXITCODE -eq 0) { + Write-Success "ripgrep installed via chocolatey" + $script:HasRipgrep = $true + return $true + } + } catch { } + } + + if ($hasScoop) { + try { + scoop install ripgrep 2>&1 | Out-Null + if ($LASTEXITCODE -eq 0) { + Write-Success "ripgrep installed via scoop" + $script:HasRipgrep = $true + return $true + } + } catch { } + } + + Write-Warning "Auto-install failed. You can install manually:" + } else { + Write-Info "Skipping ripgrep installation. 
To install manually:" + } + + # Show manual install instructions + Write-Info " winget install BurntSushi.ripgrep.MSVC" + Write-Info " Or: choco install ripgrep" + Write-Info " Or: scoop install ripgrep" + Write-Info " Or download from: https://github.com/BurntSushi/ripgrep/releases" + + $script:HasRipgrep = $false + return $true # Don't fail - ripgrep is optional +} + # ============================================================================ # Installation # ============================================================================ @@ -405,6 +477,20 @@ function Write-Completion { Write-Host "" Write-Host "⚔ Restart your terminal for PATH changes to take effect" -ForegroundColor Yellow Write-Host "" + + # Show notes about optional tools + if (-not $HasNode) { + Write-Host "Note: Node.js was not found. Browser automation tools" -ForegroundColor Yellow + Write-Host "will have limited functionality." -ForegroundColor Yellow + Write-Host "" + } + + if (-not $HasRipgrep) { + Write-Host "Note: ripgrep (rg) was not found. File search will use" -ForegroundColor Yellow + Write-Host "findstr as a fallback. For faster search:" -ForegroundColor Yellow + Write-Host " winget install BurntSushi.ripgrep.MSVC" -ForegroundColor Yellow + Write-Host "" + } } # ============================================================================ @@ -416,7 +502,8 @@ function Main { if (-not (Test-Python)) { exit 1 } if (-not (Test-Git)) { exit 1 } - Test-Node # Optional, doesn't fail + Test-Node # Optional, doesn't fail + Test-Ripgrep # Optional, doesn't fail Install-Repository Install-Venv diff --git a/scripts/install.sh b/scripts/install.sh index 4b8affaa..c3ff5a79 100755 --- a/scripts/install.sh +++ b/scripts/install.sh @@ -271,6 +271,120 @@ check_node() { # Don't exit - Node is optional } +check_ripgrep() { + log_info "Checking ripgrep (optional, for faster file search)..." 
+ + if command -v rg &> /dev/null; then + RG_VERSION=$(rg --version | head -1) + log_success "$RG_VERSION found" + HAS_RIPGREP=true + return 0 + fi + + log_warn "ripgrep not found (file search will use grep fallback)" + + # Offer to install + echo "" + read -p "Would you like to install ripgrep? (faster search, recommended) [Y/n] " -n 1 -r + echo + + if [[ $REPLY =~ ^[Yy]$ ]] || [[ -z $REPLY ]]; then + log_info "Installing ripgrep..." + + # Check if we can use sudo + CAN_SUDO=false + if command -v sudo &> /dev/null; then + # Check if user has sudo access (without actually running sudo) + if sudo -n true 2>/dev/null || sudo -v 2>/dev/null; then + CAN_SUDO=true + fi + fi + + case "$OS" in + linux) + if [ "$CAN_SUDO" = true ]; then + case "$DISTRO" in + ubuntu|debian) + if sudo apt install -y ripgrep 2>/dev/null; then + log_success "ripgrep installed" + HAS_RIPGREP=true + return 0 + fi + ;; + fedora) + if sudo dnf install -y ripgrep 2>/dev/null; then + log_success "ripgrep installed" + HAS_RIPGREP=true + return 0 + fi + ;; + arch) + if sudo pacman -S --noconfirm ripgrep 2>/dev/null; then + log_success "ripgrep installed" + HAS_RIPGREP=true + return 0 + fi + ;; + esac + else + log_warn "sudo not available - cannot auto-install system packages" + # Try cargo as fallback if available + if command -v cargo &> /dev/null; then + log_info "Trying cargo install (no sudo required)..." + if cargo install ripgrep 2>/dev/null; then + log_success "ripgrep installed via cargo" + HAS_RIPGREP=true + return 0 + fi + fi + fi + ;; + macos) + if command -v brew &> /dev/null; then + if brew install ripgrep 2>/dev/null; then + log_success "ripgrep installed" + HAS_RIPGREP=true + return 0 + fi + fi + ;; + esac + log_warn "Auto-install failed. You can install manually later:" + else + log_info "Skipping ripgrep installation. 
To install manually:" + fi + + # Show manual install instructions + case "$OS" in + linux) + case "$DISTRO" in + ubuntu|debian) + log_info " sudo apt install ripgrep" + ;; + fedora) + log_info " sudo dnf install ripgrep" + ;; + arch) + log_info " sudo pacman -S ripgrep" + ;; + *) + log_info " https://github.com/BurntSushi/ripgrep#installation" + ;; + esac + # Show cargo alternative for users without sudo + if command -v cargo &> /dev/null; then + log_info " Or without sudo: cargo install ripgrep" + fi + ;; + macos) + log_info " brew install ripgrep" + ;; + esac + + HAS_RIPGREP=false + # Don't exit - ripgrep is optional (grep fallback exists) +} + # ============================================================================ # Installation # ============================================================================ @@ -540,6 +654,15 @@ print_success() { echo "if you need full browser support." echo -e "${NC}" fi + + # Show ripgrep note if not installed + if [ "$HAS_RIPGREP" = false ]; then + echo -e "${YELLOW}" + echo "Note: ripgrep (rg) was not found. File search will use" + echo "grep as a fallback. For faster search in large codebases," + echo "install ripgrep: sudo apt install ripgrep (or brew install ripgrep)" + echo -e "${NC}" + fi } # ============================================================================ @@ -553,6 +676,7 @@ main() { check_python check_git check_node + check_ripgrep clone_repo setup_venv diff --git a/setup-hermes.sh b/setup-hermes.sh index 4cffdc73..e22511b3 100755 --- a/setup-hermes.sh +++ b/setup-hermes.sh @@ -80,6 +80,53 @@ pip install -e ".[all]" > /dev/null 2>&1 || pip install -e "." > /dev/null echo -e "${GREEN}āœ“${NC} Dependencies installed" +# ============================================================================ +# Optional: ripgrep (for faster file search) +# ============================================================================ + +echo -e "${CYAN}→${NC} Checking ripgrep (optional, for faster search)..." 
+ +if command -v rg &> /dev/null; then + echo -e "${GREEN}āœ“${NC} ripgrep found" +else + echo -e "${YELLOW}⚠${NC} ripgrep not found (file search will use grep fallback)" + read -p "Install ripgrep for faster search? [Y/n] " -n 1 -r + echo + if [[ $REPLY =~ ^[Yy]$ ]] || [[ -z $REPLY ]]; then + INSTALLED=false + + # Check if sudo is available + if command -v sudo &> /dev/null && sudo -n true 2>/dev/null; then + if command -v apt &> /dev/null; then + sudo apt install -y ripgrep && INSTALLED=true + elif command -v dnf &> /dev/null; then + sudo dnf install -y ripgrep && INSTALLED=true + fi + fi + + # Try brew (no sudo needed) + if [ "$INSTALLED" = false ] && command -v brew &> /dev/null; then + brew install ripgrep && INSTALLED=true + fi + + # Try cargo (no sudo needed) + if [ "$INSTALLED" = false ] && command -v cargo &> /dev/null; then + echo -e "${CYAN}→${NC} Trying cargo install (no sudo required)..." + cargo install ripgrep && INSTALLED=true + fi + + if [ "$INSTALLED" = true ]; then + echo -e "${GREEN}āœ“${NC} ripgrep installed" + else + echo -e "${YELLOW}⚠${NC} Auto-install failed. 
Install options:" + echo " sudo apt install ripgrep # Debian/Ubuntu" + echo " brew install ripgrep # macOS" + echo " cargo install ripgrep # With Rust (no sudo)" + echo " https://github.com/BurntSushi/ripgrep#installation" + fi + fi +fi + # ============================================================================ # Environment file # ============================================================================ diff --git a/tools/__init__.py b/tools/__init__.py index 0b6bcdcc..004a6add 100644 --- a/tools/__init__.py +++ b/tools/__init__.py @@ -111,6 +111,22 @@ from .rl_training_tool import ( get_missing_keys, ) +# File manipulation tools (read, write, patch, search) +from .file_tools import ( + read_file_tool, + write_file_tool, + patch_tool, + search_tool, + get_file_tools, + clear_file_ops_cache, +) + +# File tools have no external requirements - they use the terminal backend +def check_file_requirements(): + """File tools only require terminal backend to be available.""" + from .terminal_tool import check_terminal_requirements + return check_terminal_requirements() + __all__ = [ # Web tools 'web_search_tool', @@ -181,5 +197,13 @@ __all__ = [ 'rl_test_inference', 'check_rl_api_keys', 'get_missing_keys', + # File manipulation tools + 'read_file_tool', + 'write_file_tool', + 'patch_tool', + 'search_tool', + 'get_file_tools', + 'clear_file_ops_cache', + 'check_file_requirements', ] diff --git a/tools/file_operations.py b/tools/file_operations.py new file mode 100644 index 00000000..2509df3c --- /dev/null +++ b/tools/file_operations.py @@ -0,0 +1,937 @@ +#!/usr/bin/env python3 +""" +File Operations Module + +Provides file manipulation capabilities (read, write, patch, search) that work +across all terminal backends (local, docker, singularity, ssh, modal). + +The key insight is that all file operations can be expressed as shell commands, +so we wrap the terminal backend's execute() interface to provide a unified file API. 
+ +Usage: + from tools.file_operations import ShellFileOperations + from tools.terminal_tool import _active_environments + + # Get file operations for a terminal environment + file_ops = ShellFileOperations(terminal_env) + + # Read a file + result = file_ops.read_file("/path/to/file.py") + + # Write a file + result = file_ops.write_file("/path/to/new.py", "print('hello')") + + # Search for content + result = file_ops.search("TODO", path=".", file_glob="*.py") +""" + +import os +import re +import json +import uuid +import difflib +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from typing import Optional, List, Dict, Any, Tuple +from pathlib import Path + + +# ============================================================================= +# Result Data Classes +# ============================================================================= + +@dataclass +class ReadResult: + """Result from reading a file.""" + content: str = "" + total_lines: int = 0 + file_size: int = 0 + truncated: bool = False + hint: Optional[str] = None + is_binary: bool = False + is_image: bool = False + base64_content: Optional[str] = None + mime_type: Optional[str] = None + dimensions: Optional[str] = None # For images: "WIDTHxHEIGHT" + error: Optional[str] = None + similar_files: List[str] = field(default_factory=list) + + def to_dict(self) -> dict: + return {k: v for k, v in self.__dict__.items() if v is not None and v != [] and v != ""} + + +@dataclass +class WriteResult: + """Result from writing a file.""" + bytes_written: int = 0 + dirs_created: bool = False + error: Optional[str] = None + warning: Optional[str] = None + + def to_dict(self) -> dict: + return {k: v for k, v in self.__dict__.items() if v is not None} + + +@dataclass +class PatchResult: + """Result from patching a file.""" + success: bool = False + diff: str = "" + files_modified: List[str] = field(default_factory=list) + files_created: List[str] = field(default_factory=list) + files_deleted: 
List[str] = field(default_factory=list) + lint: Optional[Dict[str, Any]] = None + error: Optional[str] = None + + def to_dict(self) -> dict: + result = {"success": self.success} + if self.diff: + result["diff"] = self.diff + if self.files_modified: + result["files_modified"] = self.files_modified + if self.files_created: + result["files_created"] = self.files_created + if self.files_deleted: + result["files_deleted"] = self.files_deleted + if self.lint: + result["lint"] = self.lint + if self.error: + result["error"] = self.error + return result + + +@dataclass +class SearchMatch: + """A single search match.""" + path: str + line_number: int + content: str + mtime: float = 0.0 # Modification time for sorting + + +@dataclass +class SearchResult: + """Result from searching.""" + matches: List[SearchMatch] = field(default_factory=list) + files: List[str] = field(default_factory=list) + counts: Dict[str, int] = field(default_factory=dict) + total_count: int = 0 + truncated: bool = False + error: Optional[str] = None + + def to_dict(self) -> dict: + result = {"total_count": self.total_count} + if self.matches: + result["matches"] = [ + {"path": m.path, "line": m.line_number, "content": m.content} + for m in self.matches + ] + if self.files: + result["files"] = self.files + if self.counts: + result["counts"] = self.counts + if self.truncated: + result["truncated"] = True + if self.error: + result["error"] = self.error + return result + + +@dataclass +class LintResult: + """Result from linting a file.""" + success: bool = True + skipped: bool = False + output: str = "" + message: str = "" + + def to_dict(self) -> dict: + if self.skipped: + return {"status": "skipped", "message": self.message} + return { + "status": "ok" if self.success else "error", + "output": self.output + } + + +@dataclass +class ExecuteResult: + """Result from executing a shell command.""" + stdout: str = "" + exit_code: int = 0 + + +# 
============================================================================= +# Abstract Interface +# ============================================================================= + +class FileOperations(ABC): + """Abstract interface for file operations across terminal backends.""" + + @abstractmethod + def read_file(self, path: str, offset: int = 1, limit: int = 500) -> ReadResult: + """Read a file with pagination support.""" + ... + + @abstractmethod + def write_file(self, path: str, content: str) -> WriteResult: + """Write content to a file, creating directories as needed.""" + ... + + @abstractmethod + def patch_replace(self, path: str, old_string: str, new_string: str, + replace_all: bool = False) -> PatchResult: + """Replace text in a file using fuzzy matching.""" + ... + + @abstractmethod + def patch_v4a(self, patch_content: str) -> PatchResult: + """Apply a V4A format patch.""" + ... + + @abstractmethod + def search(self, pattern: str, path: str = ".", target: str = "content", + file_glob: Optional[str] = None, limit: int = 50, offset: int = 0, + output_mode: str = "content", context: int = 0) -> SearchResult: + """Search for content or files.""" + ... 

# =============================================================================
# Shell-based Implementation
# =============================================================================

# Binary file extensions (fast path check)
BINARY_EXTENSIONS = {
    # Images
    '.png', '.jpg', '.jpeg', '.gif', '.webp', '.bmp', '.ico', '.tiff', '.tif',
    '.svg',  # SVG is text but often treated as binary
    # Audio/Video
    '.mp3', '.mp4', '.wav', '.avi', '.mov', '.mkv', '.flac', '.ogg', '.webm',
    # Archives
    '.zip', '.tar', '.gz', '.bz2', '.xz', '.7z', '.rar',
    # Documents
    '.pdf', '.doc', '.docx', '.xls', '.xlsx', '.ppt', '.pptx',
    # Compiled/Binary
    '.exe', '.dll', '.so', '.dylib', '.o', '.a', '.pyc', '.pyo', '.class',
    '.wasm', '.bin',
    # Fonts
    '.ttf', '.otf', '.woff', '.woff2', '.eot',
    # Other
    '.db', '.sqlite', '.sqlite3',
}

# Image extensions (subset of binary that we can return as base64)
IMAGE_EXTENSIONS = {'.png', '.jpg', '.jpeg', '.gif', '.webp', '.bmp', '.ico'}

# Linters by file extension
LINTERS = {
    '.py': 'python -m py_compile {file} 2>&1',
    '.js': 'node --check {file} 2>&1',
    '.ts': 'npx tsc --noEmit {file} 2>&1',
    '.go': 'go vet {file} 2>&1',
    '.rs': 'rustfmt --check {file} 2>&1',
}

# Max limits for read operations
MAX_LINES = 2000
MAX_LINE_LENGTH = 2000
MAX_FILE_SIZE = 50 * 1024  # 50KB


class ShellFileOperations(FileOperations):
    """
    File operations implemented via shell commands.

    Works with ANY terminal backend that has execute(command, cwd) method.
    This includes local, docker, singularity, ssh, and modal environments.
    """

    def __init__(self, terminal_env, cwd: str = None):
        """
        Initialize file operations with a terminal environment.

        Args:
            terminal_env: Any object with execute(command, cwd) method.
                          Returns {"output": str, "returncode": int}
            cwd: Working directory (defaults to env's cwd or /tmp)
        """
        self.env = terminal_env
        # Determine cwd from various possible sources: explicit arg, the
        # environment itself, its config object, then /tmp as last resort.
        self.cwd = cwd or getattr(terminal_env, 'cwd', None) or \
                   getattr(getattr(terminal_env, 'config', None), 'cwd', None) or '/tmp'

        # Cache for command availability checks
        self._command_cache: Dict[str, bool] = {}

    def _exec(self, command: str, cwd: str = None, timeout: int = None) -> ExecuteResult:
        """Execute command via terminal backend, normalizing the result shape."""
        kwargs = {}
        if timeout:
            kwargs['timeout'] = timeout

        result = self.env.execute(command, cwd=cwd or self.cwd, **kwargs)
        return ExecuteResult(
            stdout=result.get("output", ""),
            exit_code=result.get("returncode", 0)
        )

    def _has_command(self, cmd: str) -> bool:
        """Check if a command exists in the environment (cached)."""
        if cmd not in self._command_cache:
            result = self._exec(f"command -v {cmd} >/dev/null 2>&1 && echo 'yes'")
            self._command_cache[cmd] = result.stdout.strip() == 'yes'
        return self._command_cache[cmd]

    def _is_likely_binary(self, path: str, content_sample: str = None) -> bool:
        """
        Check if a file is likely binary.

        Uses extension check (fast) + content analysis (fallback).
        """
        ext = os.path.splitext(path)[1].lower()
        if ext in BINARY_EXTENSIONS:
            return True

        # Content analysis: >30% non-printable chars = binary
        if content_sample:
            non_printable = sum(1 for c in content_sample[:1000]
                                if ord(c) < 32 and c not in '\n\r\t')
            return non_printable / min(len(content_sample), 1000) > 0.30

        return False

    def _is_image(self, path: str) -> bool:
        """Check if file is an image we can return as base64."""
        ext = os.path.splitext(path)[1].lower()
        return ext in IMAGE_EXTENSIONS

    def _add_line_numbers(self, content: str, start_line: int = 1) -> str:
        """Add line numbers to content in LINE_NUM|CONTENT format."""
        lines = content.split('\n')
        numbered = []
        for i, line in enumerate(lines, start=start_line):
            # Truncate long lines so one pathological line can't blow up output
            if len(line) > MAX_LINE_LENGTH:
                line = line[:MAX_LINE_LENGTH] + "... [truncated]"
            numbered.append(f"{i:6d}|{line}")
        return '\n'.join(numbered)

    def _expand_path(self, path: str) -> str:
        """
        Expand shell-style paths like ~ and ~user to absolute paths.

        This must be done BEFORE shell escaping, since ~ doesn't expand
        inside single quotes.
        """
        if not path:
            return path

        # Handle ~ and ~user
        if path.startswith('~'):
            # Get home directory via the terminal environment (not the host's)
            result = self._exec("echo $HOME")
            if result.exit_code == 0 and result.stdout.strip():
                home = result.stdout.strip()
                if path == '~':
                    return home
                elif path.startswith('~/'):
                    return home + path[1:]  # Replace ~ with home
            # ~username format - let shell expand it
            expand_result = self._exec(f"echo {path}")
            if expand_result.exit_code == 0:
                return expand_result.stdout.strip()

        return path

    def _escape_shell_arg(self, arg: str) -> str:
        """Escape a string for safe use in shell commands."""
        # Use single quotes and escape any single quotes in the string
        return "'" + arg.replace("'", "'\"'\"'") + "'"

    def _unified_diff(self, old_content: str, new_content: str, filename: str) -> str:
        """
        Generate a unified diff between old and new content.

        Labels the diff with a/<filename> and b/<filename> so callers can
        tell which file was modified (previously the filename argument was
        ignored and a literal placeholder was emitted).
        """
        old_lines = old_content.splitlines(keepends=True)
        new_lines = new_content.splitlines(keepends=True)
        diff = difflib.unified_diff(
            old_lines, new_lines,
            fromfile=f"a/{filename}",
            tofile=f"b/{filename}"
        )
        return ''.join(diff)

    # =========================================================================
    # READ Implementation
    # =========================================================================

    def read_file(self, path: str, offset: int = 1, limit: int = 500) -> ReadResult:
        """
        Read a file with pagination, binary detection, and line numbers.

        Args:
            path: File path (absolute or relative to cwd)
            offset: Line number to start from (1-indexed, default 1)
            limit: Maximum lines to return (default 500, max 2000)

        Returns:
            ReadResult with content, metadata, or error info
        """
        # Expand ~ and other shell paths
        path = self._expand_path(path)

        # Clamp limit
        limit = min(limit, MAX_LINES)

        # Check if file exists and get metadata
        # NOTE(review): `stat -c` is GNU stat; BSD/macOS stat needs -f — confirm
        # the supported backends all provide GNU coreutils.
        stat_cmd = f"stat -c '%s' {self._escape_shell_arg(path)} 2>/dev/null"
        stat_result = self._exec(stat_cmd)

        if stat_result.exit_code != 0:
            # File not found - try to suggest similar files
            return self._suggest_similar_files(path)

        try:
            file_size = int(stat_result.stdout.strip())
        except ValueError:
            file_size = 0

        # Files larger than MAX_FILE_SIZE are still readable — pagination
        # bounds the output; the size is only reported back to the caller.

        # Check if it's an image - return base64
        if self._is_image(path):
            return self._read_image(path)

        # Read a sample to check for binary content
        sample_cmd = f"head -c 1000 {self._escape_shell_arg(path)} 2>/dev/null"
        sample_result = self._exec(sample_cmd)

        if self._is_likely_binary(path, sample_result.stdout):
            return ReadResult(
                is_binary=True,
                file_size=file_size,
                error="Binary file - cannot display as text. Use appropriate tools to handle this file type."
            )

        # Read with pagination using sed
        end_line = offset + limit - 1
        read_cmd = f"sed -n '{offset},{end_line}p' {self._escape_shell_arg(path)}"
        read_result = self._exec(read_cmd)

        if read_result.exit_code != 0:
            return ReadResult(error=f"Failed to read file: {read_result.stdout}")

        # Get total line count
        wc_cmd = f"wc -l < {self._escape_shell_arg(path)}"
        wc_result = self._exec(wc_cmd)
        try:
            total_lines = int(wc_result.stdout.strip())
        except ValueError:
            total_lines = 0

        # Check if truncated
        truncated = total_lines > end_line
        hint = None
        if truncated:
            hint = f"Use offset={end_line + 1} to continue reading (showing {offset}-{end_line} of {total_lines} lines)"

        return ReadResult(
            content=self._add_line_numbers(read_result.stdout, offset),
            total_lines=total_lines,
            file_size=file_size,
            truncated=truncated,
            hint=hint
        )

    def _read_image(self, path: str) -> ReadResult:
        """Read an image file, returning base64 content."""
        # Get file size
        stat_cmd = f"stat -c '%s' {self._escape_shell_arg(path)} 2>/dev/null"
        stat_result = self._exec(stat_cmd)
        try:
            file_size = int(stat_result.stdout.strip())
        except ValueError:
            file_size = 0

        # Get base64 content (-w 0 disables line wrapping)
        b64_cmd = f"base64 -w 0 {self._escape_shell_arg(path)} 2>/dev/null"
        b64_result = self._exec(b64_cmd, timeout=30)

        if b64_result.exit_code != 0:
            return ReadResult(
                is_image=True,
                is_binary=True,
                file_size=file_size,
                error=f"Failed to read image: {b64_result.stdout}"
            )

        # Try to get dimensions (requires ImageMagick)
        dimensions = None
        if self._has_command('identify'):
            dim_cmd = f"identify -format '%wx%h' {self._escape_shell_arg(path)} 2>/dev/null"
            dim_result = self._exec(dim_cmd)
            if dim_result.exit_code == 0:
                dimensions = dim_result.stdout.strip()

        # Determine MIME type from extension
        ext = os.path.splitext(path)[1].lower()
        mime_types = {
            '.png': 'image/png',
            '.jpg': 'image/jpeg',
            '.jpeg': 'image/jpeg',
            '.gif': 'image/gif',
            '.webp': 'image/webp',
            '.bmp': 'image/bmp',
            '.ico': 'image/x-icon',
        }
        mime_type = mime_types.get(ext, 'application/octet-stream')

        return ReadResult(
            is_image=True,
            is_binary=True,
            file_size=file_size,
            base64_content=b64_result.stdout,
            mime_type=mime_type,
            dimensions=dimensions
        )

    def _suggest_similar_files(self, path: str) -> ReadResult:
        """Suggest similar files when the requested file is not found."""
        # Get directory and filename
        dir_path = os.path.dirname(path) or "."
        filename = os.path.basename(path)

        # List files in directory
        ls_cmd = f"ls -1 {self._escape_shell_arg(dir_path)} 2>/dev/null | head -20"
        ls_result = self._exec(ls_cmd)

        similar = []
        if ls_result.exit_code == 0 and ls_result.stdout.strip():
            files = ls_result.stdout.strip().split('\n')
            # Simple similarity: files that share some characters with the target
            for f in files:
                # Check if filenames share significant overlap
                common = set(filename.lower()) & set(f.lower())
                if len(common) >= len(filename) * 0.5:  # 50% character overlap
                    similar.append(os.path.join(dir_path, f))

        return ReadResult(
            error=f"File not found: {path}",
            similar_files=similar[:5]  # Limit to 5 suggestions
        )

    # =========================================================================
    # WRITE Implementation
    # =========================================================================

    def write_file(self, path: str, content: str) -> WriteResult:
        """
        Write content to a file, creating parent directories as needed.

        Uses heredoc with unique marker for safe shell execution.

        Args:
            path: File path to write
            content: Content to write

        Returns:
            WriteResult with bytes written or error
        """
        # Expand ~ and other shell paths
        path = self._expand_path(path)

        # Create parent directories
        parent = os.path.dirname(path)
        dirs_created = False

        if parent:
            mkdir_cmd = f"mkdir -p {self._escape_shell_arg(parent)}"
            mkdir_result = self._exec(mkdir_cmd)
            if mkdir_result.exit_code == 0:
                dirs_created = True

        # Generate unique marker for heredoc that won't appear in content
        marker = f"HERMES_EOF_{uuid.uuid4().hex[:8]}"
        while marker in content:
            marker = f"HERMES_EOF_{uuid.uuid4().hex[:8]}"

        # Write using heredoc with single-quoted marker (prevents all expansion)
        # The single quotes around the marker prevent variable expansion
        write_cmd = f"cat > {self._escape_shell_arg(path)} << '{marker}'\n{content}\n{marker}"
        write_result = self._exec(write_cmd)

        if write_result.exit_code != 0:
            return WriteResult(error=f"Failed to write file: {write_result.stdout}")

        # Get bytes written
        stat_cmd = f"stat -c '%s' {self._escape_shell_arg(path)} 2>/dev/null"
        stat_result = self._exec(stat_cmd)

        try:
            bytes_written = int(stat_result.stdout.strip())
        except ValueError:
            bytes_written = len(content.encode('utf-8'))

        return WriteResult(
            bytes_written=bytes_written,
            dirs_created=dirs_created
        )

    # =========================================================================
    # PATCH Implementation (Replace Mode)
    # =========================================================================

    def patch_replace(self, path: str, old_string: str, new_string: str,
                      replace_all: bool = False) -> PatchResult:
        """
        Replace text in a file using fuzzy matching.

        Args:
            path: File path to modify
            old_string: Text to find (must be unique unless replace_all=True)
            new_string: Replacement text
            replace_all: If True, replace all occurrences

        Returns:
            PatchResult with diff and lint results
        """
        # Expand ~ and other shell paths
        path = self._expand_path(path)

        # Read current content
        read_cmd = f"cat {self._escape_shell_arg(path)} 2>/dev/null"
        read_result = self._exec(read_cmd)

        if read_result.exit_code != 0:
            return PatchResult(error=f"Failed to read file: {path}")

        content = read_result.stdout

        # Import and use fuzzy matching (local import avoids a module cycle)
        from tools.fuzzy_match import fuzzy_find_and_replace

        new_content, match_count, error = fuzzy_find_and_replace(
            content, old_string, new_string, replace_all
        )

        if error:
            return PatchResult(error=error)

        if match_count == 0:
            return PatchResult(error=f"Could not find match for old_string in {path}")

        # Write back
        write_result = self.write_file(path, new_content)
        if write_result.error:
            return PatchResult(error=f"Failed to write changes: {write_result.error}")

        # Generate diff
        diff = self._unified_diff(content, new_content, path)

        # Auto-lint
        lint_result = self._check_lint(path)

        return PatchResult(
            success=True,
            diff=diff,
            files_modified=[path],
            lint=lint_result.to_dict() if lint_result else None
        )

    def patch_v4a(self, patch_content: str) -> PatchResult:
        """
        Apply a V4A format patch.

        V4A format:
            *** Begin Patch
            *** Update File: path/to/file.py
            @@ context hint @@
             context line
            -removed line
            +added line
            *** End Patch

        Args:
            patch_content: V4A format patch string

        Returns:
            PatchResult with changes made
        """
        # Import patch parser (local import avoids a module cycle)
        from tools.patch_parser import parse_v4a_patch, apply_v4a_operations

        operations, parse_error = parse_v4a_patch(patch_content)
        if parse_error:
            return PatchResult(error=f"Failed to parse patch: {parse_error}")

        # Apply operations
        result = apply_v4a_operations(operations, self)
        return result

    def _check_lint(self, path: str) -> LintResult:
        """
        Run syntax check on a file after editing.

        Args:
            path: File path to lint

        Returns:
            LintResult with status and any errors
        """
        ext = os.path.splitext(path)[1].lower()

        if ext not in LINTERS:
            return LintResult(skipped=True, message=f"No linter for {ext} files")

        # Check if linter command is available
        linter_cmd = LINTERS[ext]
        # Extract the base command (first word)
        base_cmd = linter_cmd.split()[0]

        if not self._has_command(base_cmd):
            return LintResult(skipped=True, message=f"{base_cmd} not available")

        # Run linter
        cmd = linter_cmd.format(file=self._escape_shell_arg(path))
        result = self._exec(cmd, timeout=30)

        return LintResult(
            success=result.exit_code == 0,
            output=result.stdout.strip() if result.stdout.strip() else ""
        )

    # =========================================================================
    # SEARCH Implementation
    # =========================================================================

    def search(self, pattern: str, path: str = ".", target: str = "content",
               file_glob: Optional[str] = None, limit: int = 50, offset: int = 0,
               output_mode: str = "content", context: int = 0) -> SearchResult:
        """
        Search for content or files.

        Args:
            pattern: Regex (for content) or glob pattern (for files)
            path: Directory/file to search (default: cwd)
            target: "content" (grep) or "files" (glob)
            file_glob: File pattern filter for content search (e.g., "*.py")
            limit: Max results (default 50)
            offset: Skip first N results
            output_mode: "content", "files_only", or "count"
            context: Lines of context around matches

        Returns:
            SearchResult with matches or file list
        """
        # Expand ~ and other shell paths
        path = self._expand_path(path)

        if target == "files":
            return self._search_files(pattern, path, limit, offset)
        else:
            return self._search_content(pattern, path, file_glob, limit, offset,
                                        output_mode, context)

    def _search_files(self, pattern: str, path: str, limit: int, offset: int) -> SearchResult:
        """Search for files by name pattern (glob-like)."""
        # Check if find is available (not on Windows without Git Bash/WSL)
        if not self._has_command('find'):
            return SearchResult(
                error="File search requires 'find' command. "
                      "On Windows, use Git Bash, WSL, or install Unix tools."
            )

        # Bare names are matched as-is; path-style patterns (containing '/'
        # or a leading '**/') are reduced to their basename, since `find -name`
        # only matches against the final path component.
        if not pattern.startswith('**/') and '/' not in pattern:
            search_pattern = pattern
        else:
            search_pattern = pattern.split('/')[-1]

        # Use find with modification time sorting
        # -printf '%T@ %p\n' outputs: timestamp path
        # sort -rn sorts by timestamp descending (newest first)
        cmd = f"find {self._escape_shell_arg(path)} -type f -name {self._escape_shell_arg(search_pattern)} " \
              f"-printf '%T@ %p\\n' 2>/dev/null | sort -rn | tail -n +{offset + 1} | head -n {limit}"

        result = self._exec(cmd, timeout=60)

        if result.exit_code != 0 and not result.stdout.strip():
            # Try without -printf (BSD find compatibility)
            cmd_simple = f"find {self._escape_shell_arg(path)} -type f -name {self._escape_shell_arg(search_pattern)} " \
                         f"2>/dev/null | head -n {limit + offset} | tail -n +{offset + 1}"
            result = self._exec(cmd_simple, timeout=60)

        files = []
        for line in result.stdout.strip().split('\n'):
            if not line:
                continue
            # Parse "timestamp path" format
            parts = line.split(' ', 1)
            if len(parts) == 2 and parts[0].replace('.', '').isdigit():
                files.append(parts[1])
            else:
                files.append(line)

        return SearchResult(
            files=files,
            total_count=len(files)
        )

    def _search_content(self, pattern: str, path: str, file_glob: Optional[str],
                        limit: int, offset: int, output_mode: str, context: int) -> SearchResult:
        """Search for content inside files (grep-like)."""
        # Try ripgrep first (fast), fallback to grep (slower but works)
        if self._has_command('rg'):
            return self._search_with_rg(pattern, path, file_glob, limit, offset,
                                        output_mode, context)
        elif self._has_command('grep'):
            return self._search_with_grep(pattern, path, file_glob, limit, offset,
                                          output_mode, context)
        else:
            # Neither rg nor grep available (Windows without Git Bash, etc.)
            return SearchResult(
                error="Content search requires ripgrep (rg) or grep. "
                      "Install ripgrep: https://github.com/BurntSushi/ripgrep#installation"
            )

    def _search_with_rg(self, pattern: str, path: str, file_glob: Optional[str],
                        limit: int, offset: int, output_mode: str, context: int) -> SearchResult:
        """Search using ripgrep."""
        cmd_parts = ["rg", "--line-number", "--no-heading"]

        # Add context if requested
        if context > 0:
            cmd_parts.extend(["-C", str(context)])

        # Add file glob filter
        if file_glob:
            cmd_parts.extend(["--glob", file_glob])

        # Output mode handling
        if output_mode == "files_only":
            cmd_parts.append("-l")  # Files only
        elif output_mode == "count":
            cmd_parts.append("-c")  # Count per file

        # Add pattern and path
        cmd_parts.append(self._escape_shell_arg(pattern))
        cmd_parts.append(self._escape_shell_arg(path))

        # Limit results
        cmd_parts.extend(["|", "head", "-n", str(limit + offset)])

        cmd = " ".join(cmd_parts)
        result = self._exec(cmd, timeout=60)

        # Parse results based on output mode
        if output_mode == "files_only":
            files = [f for f in result.stdout.strip().split('\n') if f][offset:]
            return SearchResult(files=files[:limit], total_count=len(files))

        elif output_mode == "count":
            counts = {}
            for line in result.stdout.strip().split('\n'):
                if ':' in line:
                    parts = line.rsplit(':', 1)
                    if len(parts) == 2:
                        try:
                            counts[parts[0]] = int(parts[1])
                        except ValueError:
                            pass
            return SearchResult(counts=counts, total_count=sum(counts.values()))

        else:
            # Parse content matches
            matches = []
            for line in result.stdout.strip().split('\n')[offset:]:
                if not line:
                    continue
                # Format: file:line:content
                parts = line.split(':', 2)
                if len(parts) >= 3:
                    try:
                        matches.append(SearchMatch(
                            path=parts[0],
                            line_number=int(parts[1]),
                            content=parts[2][:500]  # Truncate long lines
                        ))
                    except ValueError:
                        # Line number not an int, skip
                        pass

            return SearchResult(
                matches=matches[:limit],
                total_count=len(matches),
                truncated=len(matches) > limit
            )

def _search_with_grep(self, pattern: str, path: str, file_glob: Optional[str], + limit: int, offset: int, output_mode: str, context: int) -> SearchResult: + """Fallback search using grep.""" + cmd_parts = ["grep", "-rn"] + + # Add context if requested + if context > 0: + cmd_parts.extend(["-C", str(context)]) + + # Add file pattern filter + if file_glob: + cmd_parts.extend(["--include", file_glob]) + + # Output mode handling + if output_mode == "files_only": + cmd_parts.append("-l") + elif output_mode == "count": + cmd_parts.append("-c") + + # Add pattern and path + cmd_parts.append(self._escape_shell_arg(pattern)) + cmd_parts.append(self._escape_shell_arg(path)) + + # Limit and offset + cmd_parts.extend(["|", "tail", "-n", f"+{offset + 1}", "|", "head", "-n", str(limit)]) + + cmd = " ".join(cmd_parts) + result = self._exec(cmd, timeout=60) + + # Parse results (same format as rg) + if output_mode == "files_only": + files = [f for f in result.stdout.strip().split('\n') if f] + return SearchResult(files=files, total_count=len(files)) + + elif output_mode == "count": + counts = {} + for line in result.stdout.strip().split('\n'): + if ':' in line: + parts = line.rsplit(':', 1) + if len(parts) == 2: + try: + counts[parts[0]] = int(parts[1]) + except ValueError: + pass + return SearchResult(counts=counts, total_count=sum(counts.values())) + + else: + matches = [] + for line in result.stdout.strip().split('\n'): + if not line: + continue + parts = line.split(':', 2) + if len(parts) >= 3: + try: + matches.append(SearchMatch( + path=parts[0], + line_number=int(parts[1]), + content=parts[2][:500] + )) + except ValueError: + pass + + return SearchResult( + matches=matches, + total_count=len(matches) + ) diff --git a/tools/file_tools.py b/tools/file_tools.py new file mode 100644 index 00000000..71704fba --- /dev/null +++ b/tools/file_tools.py @@ -0,0 +1,113 @@ +#!/usr/bin/env python3 +"""File Tools Module - LLM agent file manipulation tools.""" + +import json +import threading 
+from typing import Optional +from tools.file_operations import ShellFileOperations + +_file_ops_lock = threading.Lock() +_file_ops_cache: dict = {} + + +def _get_file_ops(task_id: str = "default") -> ShellFileOperations: + """Get or create ShellFileOperations for a terminal environment.""" + from tools.terminal_tool import _active_environments, _env_lock, _LocalEnvironment + + with _file_ops_lock: + if task_id in _file_ops_cache: + return _file_ops_cache[task_id] + + with _env_lock: + if task_id not in _active_environments: + import os + env = _LocalEnvironment(cwd=os.getcwd(), timeout=60) + _active_environments[task_id] = env + terminal_env = _active_environments[task_id] + + file_ops = ShellFileOperations(terminal_env) + _file_ops_cache[task_id] = file_ops + return file_ops + + +def clear_file_ops_cache(task_id: str = None): + """Clear the file operations cache.""" + with _file_ops_lock: + if task_id: + _file_ops_cache.pop(task_id, None) + else: + _file_ops_cache.clear() + + +def read_file_tool(path: str, offset: int = 1, limit: int = 500, task_id: str = "default") -> str: + """Read a file with pagination and line numbers.""" + try: + file_ops = _get_file_ops(task_id) + result = file_ops.read_file(path, offset, limit) + return json.dumps(result.to_dict(), ensure_ascii=False) + except Exception as e: + return json.dumps({"error": str(e)}, ensure_ascii=False) + + +def write_file_tool(path: str, content: str, task_id: str = "default") -> str: + """Write content to a file.""" + try: + file_ops = _get_file_ops(task_id) + result = file_ops.write_file(path, content) + return json.dumps(result.to_dict(), ensure_ascii=False) + except Exception as e: + return json.dumps({"error": str(e)}, ensure_ascii=False) + + +def patch_tool(mode: str = "replace", path: str = None, old_string: str = None, + new_string: str = None, replace_all: bool = False, patch: str = None, + task_id: str = "default") -> str: + """Patch a file using replace mode or V4A patch format.""" + try: + 
file_ops = _get_file_ops(task_id) + + if mode == "replace": + if not path: + return json.dumps({"error": "path required"}) + if old_string is None or new_string is None: + return json.dumps({"error": "old_string and new_string required"}) + result = file_ops.patch_replace(path, old_string, new_string, replace_all) + elif mode == "patch": + if not patch: + return json.dumps({"error": "patch content required"}) + result = file_ops.patch_v4a(patch) + else: + return json.dumps({"error": f"Unknown mode: {mode}"}) + + return json.dumps(result.to_dict(), ensure_ascii=False) + except Exception as e: + return json.dumps({"error": str(e)}, ensure_ascii=False) + + +def search_tool(pattern: str, target: str = "content", path: str = ".", + file_glob: str = None, limit: int = 50, offset: int = 0, + output_mode: str = "content", context: int = 0, + task_id: str = "default") -> str: + """Search for content or files.""" + try: + file_ops = _get_file_ops(task_id) + result = file_ops.search( + pattern=pattern, path=path, target=target, file_glob=file_glob, + limit=limit, offset=offset, output_mode=output_mode, context=context + ) + return json.dumps(result.to_dict(), ensure_ascii=False) + except Exception as e: + return json.dumps({"error": str(e)}, ensure_ascii=False) + + +FILE_TOOLS = [ + {"name": "read_file", "function": read_file_tool}, + {"name": "write_file", "function": write_file_tool}, + {"name": "patch", "function": patch_tool}, + {"name": "search", "function": search_tool} +] + + +def get_file_tools(): + """Get the list of file tool definitions.""" + return FILE_TOOLS diff --git a/tools/fuzzy_match.py b/tools/fuzzy_match.py new file mode 100644 index 00000000..796072ff --- /dev/null +++ b/tools/fuzzy_match.py @@ -0,0 +1,478 @@ +#!/usr/bin/env python3 +""" +Fuzzy Matching Module for File Operations + +Implements a multi-strategy matching chain to robustly find and replace text, +accommodating variations in whitespace, indentation, and escaping common +in LLM-generated code. 
+ +The 9-strategy chain (inspired by OpenCode): +1. Exact match - Direct string comparison +2. Line-trimmed - Strip leading/trailing whitespace per line +3. Block anchor - Match first+last lines, use similarity for middle +4. Whitespace normalized - Collapse multiple spaces/tabs to single space +5. Indentation flexible - Ignore indentation differences entirely +6. Escape normalized - Convert \\n literals to actual newlines +7. Trimmed boundary - Trim first/last line whitespace only +8. Context-aware - 50% line similarity threshold +9. Multi-occurrence - For replace_all flag + +Usage: + from tools.fuzzy_match import fuzzy_find_and_replace + + new_content, match_count, error = fuzzy_find_and_replace( + content="def foo():\\n pass", + old_string="def foo():", + new_string="def bar():", + replace_all=False + ) +""" + +import re +from typing import Tuple, Optional, List, Callable +from difflib import SequenceMatcher + + +def fuzzy_find_and_replace(content: str, old_string: str, new_string: str, + replace_all: bool = False) -> Tuple[str, int, Optional[str]]: + """ + Find and replace text using a chain of increasingly fuzzy matching strategies. 
+ + Args: + content: The file content to search in + old_string: The text to find + new_string: The replacement text + replace_all: If True, replace all occurrences; if False, require uniqueness + + Returns: + Tuple of (new_content, match_count, error_message) + - If successful: (modified_content, number_of_replacements, None) + - If failed: (original_content, 0, error_description) + """ + if not old_string: + return content, 0, "old_string cannot be empty" + + if old_string == new_string: + return content, 0, "old_string and new_string are identical" + + # Try each matching strategy in order + strategies: List[Tuple[str, Callable]] = [ + ("exact", _strategy_exact), + ("line_trimmed", _strategy_line_trimmed), + ("whitespace_normalized", _strategy_whitespace_normalized), + ("indentation_flexible", _strategy_indentation_flexible), + ("escape_normalized", _strategy_escape_normalized), + ("trimmed_boundary", _strategy_trimmed_boundary), + ("block_anchor", _strategy_block_anchor), + ("context_aware", _strategy_context_aware), + ] + + for strategy_name, strategy_fn in strategies: + matches = strategy_fn(content, old_string) + + if matches: + # Found matches with this strategy + if len(matches) > 1 and not replace_all: + return content, 0, ( + f"Found {len(matches)} matches for old_string. " + f"Provide more context to make it unique, or use replace_all=True." + ) + + # Perform replacement + new_content = _apply_replacements(content, matches, new_string) + return new_content, len(matches), None + + # No strategy found a match + return content, 0, "Could not find a match for old_string in the file" + + +def _apply_replacements(content: str, matches: List[Tuple[int, int]], new_string: str) -> str: + """ + Apply replacements at the given positions. 
+ + Args: + content: Original content + matches: List of (start, end) positions to replace + new_string: Replacement text + + Returns: + Content with replacements applied + """ + # Sort matches by position (descending) to replace from end to start + # This preserves positions of earlier matches + sorted_matches = sorted(matches, key=lambda x: x[0], reverse=True) + + result = content + for start, end in sorted_matches: + result = result[:start] + new_string + result[end:] + + return result + + +# ============================================================================= +# Matching Strategies +# ============================================================================= + +def _strategy_exact(content: str, pattern: str) -> List[Tuple[int, int]]: + """Strategy 1: Exact string match.""" + matches = [] + start = 0 + while True: + pos = content.find(pattern, start) + if pos == -1: + break + matches.append((pos, pos + len(pattern))) + start = pos + 1 + return matches + + +def _strategy_line_trimmed(content: str, pattern: str) -> List[Tuple[int, int]]: + """ + Strategy 2: Match with line-by-line whitespace trimming. + + Strips leading/trailing whitespace from each line before matching. + """ + # Normalize pattern and content by trimming each line + pattern_lines = [line.strip() for line in pattern.split('\n')] + pattern_normalized = '\n'.join(pattern_lines) + + content_lines = content.split('\n') + content_normalized_lines = [line.strip() for line in content_lines] + + # Build mapping from normalized positions back to original positions + return _find_normalized_matches( + content, content_lines, content_normalized_lines, + pattern, pattern_normalized + ) + + +def _strategy_whitespace_normalized(content: str, pattern: str) -> List[Tuple[int, int]]: + """ + Strategy 3: Collapse multiple whitespace to single space. 
+ """ + def normalize(s): + # Collapse multiple spaces/tabs to single space, preserve newlines + return re.sub(r'[ \t]+', ' ', s) + + pattern_normalized = normalize(pattern) + content_normalized = normalize(content) + + # Find in normalized, map back to original + matches_in_normalized = _strategy_exact(content_normalized, pattern_normalized) + + if not matches_in_normalized: + return [] + + # Map positions back to original content + return _map_normalized_positions(content, content_normalized, matches_in_normalized) + + +def _strategy_indentation_flexible(content: str, pattern: str) -> List[Tuple[int, int]]: + """ + Strategy 4: Ignore indentation differences entirely. + + Strips all leading whitespace from lines before matching. + """ + def strip_indent(s): + return '\n'.join(line.lstrip() for line in s.split('\n')) + + pattern_stripped = strip_indent(pattern) + + content_lines = content.split('\n') + content_stripped_lines = [line.lstrip() for line in content_lines] + pattern_lines = [line.lstrip() for line in pattern.split('\n')] + + return _find_normalized_matches( + content, content_lines, content_stripped_lines, + pattern, '\n'.join(pattern_lines) + ) + + +def _strategy_escape_normalized(content: str, pattern: str) -> List[Tuple[int, int]]: + """ + Strategy 5: Convert escape sequences to actual characters. + + Handles \\n -> newline, \\t -> tab, etc. + """ + def unescape(s): + # Convert common escape sequences + return s.replace('\\n', '\n').replace('\\t', '\t').replace('\\r', '\r') + + pattern_unescaped = unescape(pattern) + + if pattern_unescaped == pattern: + # No escapes to convert, skip this strategy + return [] + + return _strategy_exact(content, pattern_unescaped) + + +def _strategy_trimmed_boundary(content: str, pattern: str) -> List[Tuple[int, int]]: + """ + Strategy 6: Trim whitespace from first and last lines only. + + Useful when the pattern boundaries have whitespace differences. 
+ """ + pattern_lines = pattern.split('\n') + if not pattern_lines: + return [] + + # Trim only first and last lines + pattern_lines[0] = pattern_lines[0].strip() + if len(pattern_lines) > 1: + pattern_lines[-1] = pattern_lines[-1].strip() + + modified_pattern = '\n'.join(pattern_lines) + + content_lines = content.split('\n') + + # Search through content for matching block + matches = [] + pattern_line_count = len(pattern_lines) + + for i in range(len(content_lines) - pattern_line_count + 1): + block_lines = content_lines[i:i + pattern_line_count] + + # Trim first and last of this block + check_lines = block_lines.copy() + check_lines[0] = check_lines[0].strip() + if len(check_lines) > 1: + check_lines[-1] = check_lines[-1].strip() + + if '\n'.join(check_lines) == modified_pattern: + # Found match - calculate original positions + start_pos = sum(len(line) + 1 for line in content_lines[:i]) + end_pos = sum(len(line) + 1 for line in content_lines[:i + pattern_line_count]) - 1 + if end_pos >= len(content): + end_pos = len(content) + matches.append((start_pos, end_pos)) + + return matches + + +def _strategy_block_anchor(content: str, pattern: str) -> List[Tuple[int, int]]: + """ + Strategy 7: Match by anchoring on first and last lines. + + If first and last lines match exactly, accept middle with 70% similarity. 
+ """ + pattern_lines = pattern.split('\n') + if len(pattern_lines) < 2: + return [] # Need at least 2 lines for anchoring + + first_line = pattern_lines[0].strip() + last_line = pattern_lines[-1].strip() + + content_lines = content.split('\n') + matches = [] + + pattern_line_count = len(pattern_lines) + + for i in range(len(content_lines) - pattern_line_count + 1): + # Check if first and last lines match + if (content_lines[i].strip() == first_line and + content_lines[i + pattern_line_count - 1].strip() == last_line): + + # Check middle similarity + if pattern_line_count <= 2: + # Only first and last, they match + similarity = 1.0 + else: + content_middle = '\n'.join(content_lines[i+1:i+pattern_line_count-1]) + pattern_middle = '\n'.join(pattern_lines[1:-1]) + similarity = SequenceMatcher(None, content_middle, pattern_middle).ratio() + + if similarity >= 0.70: + # Calculate positions + start_pos = sum(len(line) + 1 for line in content_lines[:i]) + end_pos = sum(len(line) + 1 for line in content_lines[:i + pattern_line_count]) - 1 + if end_pos >= len(content): + end_pos = len(content) + matches.append((start_pos, end_pos)) + + return matches + + +def _strategy_context_aware(content: str, pattern: str) -> List[Tuple[int, int]]: + """ + Strategy 8: Line-by-line similarity with 50% threshold. + + Finds blocks where at least 50% of lines have high similarity. 
+ """ + pattern_lines = pattern.split('\n') + content_lines = content.split('\n') + + if not pattern_lines: + return [] + + matches = [] + pattern_line_count = len(pattern_lines) + + for i in range(len(content_lines) - pattern_line_count + 1): + block_lines = content_lines[i:i + pattern_line_count] + + # Calculate line-by-line similarity + high_similarity_count = 0 + for p_line, c_line in zip(pattern_lines, block_lines): + sim = SequenceMatcher(None, p_line.strip(), c_line.strip()).ratio() + if sim >= 0.80: + high_similarity_count += 1 + + # Need at least 50% of lines to have high similarity + if high_similarity_count >= len(pattern_lines) * 0.5: + start_pos = sum(len(line) + 1 for line in content_lines[:i]) + end_pos = sum(len(line) + 1 for line in content_lines[:i + pattern_line_count]) - 1 + if end_pos >= len(content): + end_pos = len(content) + matches.append((start_pos, end_pos)) + + return matches + + +# ============================================================================= +# Helper Functions +# ============================================================================= + +def _find_normalized_matches(content: str, content_lines: List[str], + content_normalized_lines: List[str], + pattern: str, pattern_normalized: str) -> List[Tuple[int, int]]: + """ + Find matches in normalized content and map back to original positions. 
+ + Args: + content: Original content string + content_lines: Original content split by lines + content_normalized_lines: Normalized content lines + pattern: Original pattern + pattern_normalized: Normalized pattern + + Returns: + List of (start, end) positions in the original content + """ + pattern_norm_lines = pattern_normalized.split('\n') + num_pattern_lines = len(pattern_norm_lines) + + matches = [] + + for i in range(len(content_normalized_lines) - num_pattern_lines + 1): + # Check if this block matches + block = '\n'.join(content_normalized_lines[i:i + num_pattern_lines]) + + if block == pattern_normalized: + # Found a match - calculate original positions + start_pos = sum(len(line) + 1 for line in content_lines[:i]) + end_pos = sum(len(line) + 1 for line in content_lines[:i + num_pattern_lines]) - 1 + + # Handle case where end is past content + if end_pos >= len(content): + end_pos = len(content) + + matches.append((start_pos, end_pos)) + + return matches + + +def _map_normalized_positions(original: str, normalized: str, + normalized_matches: List[Tuple[int, int]]) -> List[Tuple[int, int]]: + """ + Map positions from normalized string back to original. + + This is a best-effort mapping that works for whitespace normalization. 
+ """ + if not normalized_matches: + return [] + + # Build character mapping from normalized to original + orig_to_norm = [] # orig_to_norm[i] = position in normalized + + orig_idx = 0 + norm_idx = 0 + + while orig_idx < len(original) and norm_idx < len(normalized): + if original[orig_idx] == normalized[norm_idx]: + orig_to_norm.append(norm_idx) + orig_idx += 1 + norm_idx += 1 + elif original[orig_idx] in ' \t' and normalized[norm_idx] == ' ': + # Original has space/tab, normalized collapsed to space + orig_to_norm.append(norm_idx) + orig_idx += 1 + # Don't advance norm_idx yet - wait until all whitespace consumed + if orig_idx < len(original) and original[orig_idx] not in ' \t': + norm_idx += 1 + elif original[orig_idx] in ' \t': + # Extra whitespace in original + orig_to_norm.append(norm_idx) + orig_idx += 1 + else: + # Mismatch - shouldn't happen with our normalization + orig_to_norm.append(norm_idx) + orig_idx += 1 + + # Fill remaining + while orig_idx < len(original): + orig_to_norm.append(len(normalized)) + orig_idx += 1 + + # Reverse mapping: for each normalized position, find original range + norm_to_orig_start = {} + norm_to_orig_end = {} + + for orig_pos, norm_pos in enumerate(orig_to_norm): + if norm_pos not in norm_to_orig_start: + norm_to_orig_start[norm_pos] = orig_pos + norm_to_orig_end[norm_pos] = orig_pos + + # Map matches + original_matches = [] + for norm_start, norm_end in normalized_matches: + # Find original start + if norm_start in norm_to_orig_start: + orig_start = norm_to_orig_start[norm_start] + else: + # Find nearest + orig_start = min(i for i, n in enumerate(orig_to_norm) if n >= norm_start) + + # Find original end + if norm_end - 1 in norm_to_orig_end: + orig_end = norm_to_orig_end[norm_end - 1] + 1 + else: + orig_end = orig_start + (norm_end - norm_start) + + # Expand to include trailing whitespace that was normalized + while orig_end < len(original) and original[orig_end] in ' \t': + orig_end += 1 + + 
original_matches.append((orig_start, min(orig_end, len(original)))) + + return original_matches + + +# ============================================================================= +# Utility Functions +# ============================================================================= + +def find_best_match(content: str, pattern: str) -> Optional[Tuple[int, int, str]]: + """ + Find the best match for a pattern and return the strategy name. + + Returns: + Tuple of (start, end, strategy_name) or None if no match + """ + strategies = [ + ("exact", _strategy_exact), + ("line_trimmed", _strategy_line_trimmed), + ("whitespace_normalized", _strategy_whitespace_normalized), + ("indentation_flexible", _strategy_indentation_flexible), + ("escape_normalized", _strategy_escape_normalized), + ("trimmed_boundary", _strategy_trimmed_boundary), + ("block_anchor", _strategy_block_anchor), + ("context_aware", _strategy_context_aware), + ] + + for strategy_name, strategy_fn in strategies: + matches = strategy_fn(content, pattern) + if matches: + return (matches[0][0], matches[0][1], strategy_name) + + return None diff --git a/tools/patch_parser.py b/tools/patch_parser.py new file mode 100644 index 00000000..bce7bb6e --- /dev/null +++ b/tools/patch_parser.py @@ -0,0 +1,439 @@ +#!/usr/bin/env python3 +""" +V4A Patch Format Parser + +Parses the V4A patch format used by codex, cline, and other coding agents. 
+ +V4A Format: + *** Begin Patch + *** Update File: path/to/file.py + @@ optional context hint @@ + context line (space prefix) + -removed line (minus prefix) + +added line (plus prefix) + *** Add File: path/to/new.py + +new file content + +line 2 + *** Delete File: path/to/old.py + *** Move File: old/path.py -> new/path.py + *** End Patch + +Usage: + from tools.patch_parser import parse_v4a_patch, apply_v4a_operations + + operations, error = parse_v4a_patch(patch_content) + if error: + print(f"Parse error: {error}") + else: + result = apply_v4a_operations(operations, file_ops) +""" + +import re +from dataclasses import dataclass, field +from typing import List, Optional, Tuple, Any +from enum import Enum + + +class OperationType(Enum): + ADD = "add" + UPDATE = "update" + DELETE = "delete" + MOVE = "move" + + +@dataclass +class HunkLine: + """A single line in a patch hunk.""" + prefix: str # ' ', '-', or '+' + content: str + + +@dataclass +class Hunk: + """A group of changes within a file.""" + context_hint: Optional[str] = None + lines: List[HunkLine] = field(default_factory=list) + + +@dataclass +class PatchOperation: + """A single operation in a V4A patch.""" + operation: OperationType + file_path: str + new_path: Optional[str] = None # For move operations + hunks: List[Hunk] = field(default_factory=list) + content: Optional[str] = None # For add file operations + + +def parse_v4a_patch(patch_content: str) -> Tuple[List[PatchOperation], Optional[str]]: + """ + Parse a V4A format patch. 
+ + Args: + patch_content: The patch text in V4A format + + Returns: + Tuple of (operations, error_message) + - If successful: (list_of_operations, None) + - If failed: ([], error_description) + """ + lines = patch_content.split('\n') + operations: List[PatchOperation] = [] + + # Find patch boundaries + start_idx = None + end_idx = None + + for i, line in enumerate(lines): + if '*** Begin Patch' in line or '***Begin Patch' in line: + start_idx = i + elif '*** End Patch' in line or '***End Patch' in line: + end_idx = i + break + + if start_idx is None: + # Try to parse without explicit begin marker + start_idx = -1 + + if end_idx is None: + end_idx = len(lines) + + # Parse operations between boundaries + i = start_idx + 1 + current_op: Optional[PatchOperation] = None + current_hunk: Optional[Hunk] = None + + while i < end_idx: + line = lines[i] + + # Check for file operation markers + update_match = re.match(r'\*\*\*\s*Update\s+File:\s*(.+)', line) + add_match = re.match(r'\*\*\*\s*Add\s+File:\s*(.+)', line) + delete_match = re.match(r'\*\*\*\s*Delete\s+File:\s*(.+)', line) + move_match = re.match(r'\*\*\*\s*Move\s+File:\s*(.+?)\s*->\s*(.+)', line) + + if update_match: + # Save previous operation + if current_op: + if current_hunk and current_hunk.lines: + current_op.hunks.append(current_hunk) + operations.append(current_op) + + current_op = PatchOperation( + operation=OperationType.UPDATE, + file_path=update_match.group(1).strip() + ) + current_hunk = None + + elif add_match: + if current_op: + if current_hunk and current_hunk.lines: + current_op.hunks.append(current_hunk) + operations.append(current_op) + + current_op = PatchOperation( + operation=OperationType.ADD, + file_path=add_match.group(1).strip() + ) + current_hunk = Hunk() + + elif delete_match: + if current_op: + if current_hunk and current_hunk.lines: + current_op.hunks.append(current_hunk) + operations.append(current_op) + + current_op = PatchOperation( + operation=OperationType.DELETE, + 
file_path=delete_match.group(1).strip() + ) + operations.append(current_op) + current_op = None + current_hunk = None + + elif move_match: + if current_op: + if current_hunk and current_hunk.lines: + current_op.hunks.append(current_hunk) + operations.append(current_op) + + current_op = PatchOperation( + operation=OperationType.MOVE, + file_path=move_match.group(1).strip(), + new_path=move_match.group(2).strip() + ) + operations.append(current_op) + current_op = None + current_hunk = None + + elif line.startswith('@@'): + # Context hint / hunk marker + if current_op: + if current_hunk and current_hunk.lines: + current_op.hunks.append(current_hunk) + + # Extract context hint + hint_match = re.match(r'@@\s*(.+?)\s*@@', line) + hint = hint_match.group(1) if hint_match else None + current_hunk = Hunk(context_hint=hint) + + elif current_op and line: + # Parse hunk line + if current_hunk is None: + current_hunk = Hunk() + + if line.startswith('+'): + current_hunk.lines.append(HunkLine('+', line[1:])) + elif line.startswith('-'): + current_hunk.lines.append(HunkLine('-', line[1:])) + elif line.startswith(' '): + current_hunk.lines.append(HunkLine(' ', line[1:])) + elif line.startswith('\\'): + # "\ No newline at end of file" marker - skip + pass + else: + # Treat as context line (implicit space prefix) + current_hunk.lines.append(HunkLine(' ', line)) + + i += 1 + + # Don't forget the last operation + if current_op: + if current_hunk and current_hunk.lines: + current_op.hunks.append(current_hunk) + operations.append(current_op) + + return operations, None + + +def apply_v4a_operations(operations: List[PatchOperation], + file_ops: Any) -> 'PatchResult': + """ + Apply V4A patch operations using a file operations interface. 
+ + Args: + operations: List of PatchOperation from parse_v4a_patch + file_ops: Object with read_file, write_file methods + + Returns: + PatchResult with results of all operations + """ + # Import here to avoid circular imports + from tools.file_operations import PatchResult + + files_modified = [] + files_created = [] + files_deleted = [] + all_diffs = [] + errors = [] + + for op in operations: + try: + if op.operation == OperationType.ADD: + result = _apply_add(op, file_ops) + if result[0]: + files_created.append(op.file_path) + all_diffs.append(result[1]) + else: + errors.append(f"Failed to add {op.file_path}: {result[1]}") + + elif op.operation == OperationType.DELETE: + result = _apply_delete(op, file_ops) + if result[0]: + files_deleted.append(op.file_path) + all_diffs.append(result[1]) + else: + errors.append(f"Failed to delete {op.file_path}: {result[1]}") + + elif op.operation == OperationType.MOVE: + result = _apply_move(op, file_ops) + if result[0]: + files_modified.append(f"{op.file_path} -> {op.new_path}") + all_diffs.append(result[1]) + else: + errors.append(f"Failed to move {op.file_path}: {result[1]}") + + elif op.operation == OperationType.UPDATE: + result = _apply_update(op, file_ops) + if result[0]: + files_modified.append(op.file_path) + all_diffs.append(result[1]) + else: + errors.append(f"Failed to update {op.file_path}: {result[1]}") + + except Exception as e: + errors.append(f"Error processing {op.file_path}: {str(e)}") + + # Run lint on all modified/created files + lint_results = {} + for f in files_modified + files_created: + if hasattr(file_ops, '_check_lint'): + lint_result = file_ops._check_lint(f) + lint_results[f] = lint_result.to_dict() + + combined_diff = '\n'.join(all_diffs) + + if errors: + return PatchResult( + success=False, + diff=combined_diff, + files_modified=files_modified, + files_created=files_created, + files_deleted=files_deleted, + lint=lint_results if lint_results else None, + error='; '.join(errors) + ) + + return 
PatchResult( + success=True, + diff=combined_diff, + files_modified=files_modified, + files_created=files_created, + files_deleted=files_deleted, + lint=lint_results if lint_results else None + ) + + +def _apply_add(op: PatchOperation, file_ops: Any) -> Tuple[bool, str]: + """Apply an add file operation.""" + # Extract content from hunks (all + lines) + content_lines = [] + for hunk in op.hunks: + for line in hunk.lines: + if line.prefix == '+': + content_lines.append(line.content) + + content = '\n'.join(content_lines) + + result = file_ops.write_file(op.file_path, content) + if result.error: + return False, result.error + + diff = f"--- /dev/null\n+++ b/{op.file_path}\n" + diff += '\n'.join(f"+{line}" for line in content_lines) + + return True, diff + + +def _apply_delete(op: PatchOperation, file_ops: Any) -> Tuple[bool, str]: + """Apply a delete file operation.""" + # Read file first for diff + read_result = file_ops.read_file(op.file_path) + + if read_result.error and "not found" in read_result.error.lower(): + # File doesn't exist, nothing to delete + return True, f"# {op.file_path} already deleted or doesn't exist" + + # Delete by writing empty and then removing + # Use shell command via the underlying environment + rm_result = file_ops._exec(f"rm -f {file_ops._escape_shell_arg(op.file_path)}") + + if rm_result.exit_code != 0: + return False, rm_result.stdout + + diff = f"--- a/{op.file_path}\n+++ /dev/null\n# File deleted" + return True, diff + + +def _apply_move(op: PatchOperation, file_ops: Any) -> Tuple[bool, str]: + """Apply a move file operation.""" + # Use shell mv command + mv_result = file_ops._exec( + f"mv {file_ops._escape_shell_arg(op.file_path)} {file_ops._escape_shell_arg(op.new_path)}" + ) + + if mv_result.exit_code != 0: + return False, mv_result.stdout + + diff = f"# Moved: {op.file_path} -> {op.new_path}" + return True, diff + + +def _apply_update(op: PatchOperation, file_ops: Any) -> Tuple[bool, str]: + """Apply an update file operation.""" 
+ # Read current content + read_result = file_ops.read_file(op.file_path, limit=10000) + + if read_result.error: + return False, f"Cannot read file: {read_result.error}" + + # Parse content (remove line numbers) + current_lines = [] + for line in read_result.content.split('\n'): + if '|' in line: + # Line format: " 123|content" + parts = line.split('|', 1) + if len(parts) == 2: + current_lines.append(parts[1]) + else: + current_lines.append(line) + else: + current_lines.append(line) + + current_content = '\n'.join(current_lines) + + # Apply each hunk + new_content = current_content + + for hunk in op.hunks: + # Build search pattern from context and removed lines + search_lines = [] + replace_lines = [] + + for line in hunk.lines: + if line.prefix == ' ': + search_lines.append(line.content) + replace_lines.append(line.content) + elif line.prefix == '-': + search_lines.append(line.content) + elif line.prefix == '+': + replace_lines.append(line.content) + + if search_lines: + search_pattern = '\n'.join(search_lines) + replacement = '\n'.join(replace_lines) + + # Use fuzzy matching + from tools.fuzzy_match import fuzzy_find_and_replace + new_content, count, error = fuzzy_find_and_replace( + new_content, search_pattern, replacement, replace_all=False + ) + + if error and count == 0: + # Try with context hint if available + if hunk.context_hint: + # Find the context hint location and search nearby + hint_pos = new_content.find(hunk.context_hint) + if hint_pos != -1: + # Search in a window around the hint + window_start = max(0, hint_pos - 500) + window_end = min(len(new_content), hint_pos + 2000) + window = new_content[window_start:window_end] + + window_new, count, error = fuzzy_find_and_replace( + window, search_pattern, replacement, replace_all=False + ) + + if count > 0: + new_content = new_content[:window_start] + window_new + new_content[window_end:] + error = None + + if error: + return False, f"Could not apply hunk: {error}" + + # Write new content + write_result 
= file_ops.write_file(op.file_path, new_content) + if write_result.error: + return False, write_result.error + + # Generate diff + import difflib + diff_lines = difflib.unified_diff( + current_content.splitlines(keepends=True), + new_content.splitlines(keepends=True), + fromfile=f"a/{op.file_path}", + tofile=f"b/{op.file_path}" + ) + diff = ''.join(diff_lines) + + return True, diff diff --git a/toolset_distributions.py b/toolset_distributions.py index 7eb5980a..7f829c27 100644 --- a/toolset_distributions.py +++ b/toolset_distributions.py @@ -35,6 +35,7 @@ DISTRIBUTIONS = { "vision": 100, "image_gen": 100, "terminal": 100, + "file": 100, "moa": 100, "browser": 100 } @@ -66,10 +67,11 @@ DISTRIBUTIONS = { # Scientific problem solving focused distribution "science": { - "description": "Scientific research with web, terminal, and browser capabilities", + "description": "Scientific research with web, terminal, file, and browser capabilities", "toolsets": { "web": 94, # 94% chance of web tools "terminal": 94, # 94% chance of terminal tools + "file": 94, # 94% chance of file tools "vision": 65, # 65% chance of vision tools "browser": 50, # 50% chance of browser for accessing papers/databases "image_gen": 15, # 15% chance of image generation tools @@ -79,9 +81,10 @@ DISTRIBUTIONS = { # Development-focused distribution "development": { - "description": "Terminal and reasoning with occasional web lookup", + "description": "Terminal, file tools, and reasoning with occasional web lookup", "toolsets": { "terminal": 80, # 80% chance of terminal tools + "file": 80, # 80% chance of file tools (read, write, patch, search) "moa": 60, # 60% chance of reasoning tools "web": 30, # 30% chance of web tools "vision": 10 # 10% chance of vision tools @@ -108,6 +111,7 @@ DISTRIBUTIONS = { "vision": 50, "image_gen": 50, "terminal": 50, + "file": 50, "moa": 50, "browser": 50 } @@ -123,17 +127,19 @@ DISTRIBUTIONS = { # Terminal only "terminal_only": { - "description": "Only terminal tool for 
code execution tasks", + "description": "Terminal and file tools for code execution tasks", "toolsets": { - "terminal": 100 + "terminal": 100, + "file": 100 } }, # Terminal + web (common for coding tasks that need docs) "terminal_web": { - "description": "Terminal with web search for documentation lookup", + "description": "Terminal and file tools with web search for documentation lookup", "toolsets": { "terminal": 100, + "file": 100, "web": 100 } }, @@ -188,9 +194,10 @@ DISTRIBUTIONS = { # Terminal-focused tasks distribution (for nous-terminal-tasks.jsonl) "terminal_tasks": { - "description": "Terminal-focused distribution with high terminal availability, occasional other tools", + "description": "Terminal-focused distribution with high terminal/file availability, occasional other tools", "toolsets": { "terminal": 97, # 97% - terminal almost always available + "file": 97, # 97% - file tools almost always available "web": 15, # 15% - web search/scrape for documentation "browser": 10, # 10% - browser occasionally for web interaction "vision": 8, # 8% - vision analysis rarely @@ -200,10 +207,11 @@ DISTRIBUTIONS = { # Mixed browser+terminal tasks distribution (for mixed-browser-terminal-tasks.jsonl) "mixed_tasks": { - "description": "Mixed distribution with high browser and terminal availability for complex tasks", + "description": "Mixed distribution with high browser, terminal, and file availability for complex tasks", "toolsets": { "browser": 92, # 92% - browser tools highly available - "terminal": 92, # 92% - terminal highly available + "terminal": 92, # 92% - terminal highly available + "file": 92, # 92% - file tools highly available "web": 35, # 35% - web search/scrape fairly common "vision": 15, # 15% - vision analysis occasionally "image_gen": 15 # 15% - image generation occasionally diff --git a/toolsets.py b/toolsets.py index abd6192a..7dac5ff1 100644 --- a/toolsets.py +++ b/toolsets.py @@ -102,12 +102,18 @@ TOOLSETS = { "includes": [] }, + "file": { + 
"description": "File manipulation tools: read, write, patch (with fuzzy matching), and search (content + files)", + "tools": ["read_file", "write_file", "patch", "search"], + "includes": [] + }, + # Scenario-specific toolsets "debugging": { "description": "Debugging and troubleshooting toolkit", "tools": ["terminal"], - "includes": ["web"] # For searching error messages and solutions + "includes": ["web", "file"] # For searching error messages and solutions, and file operations }, "safe": { @@ -127,6 +133,8 @@ TOOLSETS = { "web_search", "web_extract", # Terminal "terminal", + # File manipulation + "read_file", "write_file", "patch", "search", # Vision "vision_analyze", # Image generation @@ -155,6 +163,8 @@ TOOLSETS = { "tools": [ # Terminal - enabled with dangerous command approval system "terminal", + # File manipulation + "read_file", "write_file", "patch", "search", # Web tools "web_search", "web_extract", # Vision - analyze images sent by users @@ -189,6 +199,8 @@ TOOLSETS = { "web_search", "web_extract", # Terminal - only for trusted personal accounts "terminal", + # File manipulation + "read_file", "write_file", "patch", "search", # Vision "vision_analyze", # Skills