Release set of skills
This commit is contained in:
parent
757d012ab5
commit
740dd928f7
96 changed files with 2040 additions and 5300 deletions
|
|
@ -1,434 +0,0 @@
|
|||
---
|
||||
name: peft-fine-tuning
|
||||
description: Parameter-efficient fine-tuning for LLMs using LoRA, QLoRA, and 25+ methods. Use when fine-tuning large models (7B-70B) with limited GPU memory, when you need to train <1% of parameters with minimal accuracy loss, or for multi-adapter serving. HuggingFace's official library integrated with transformers ecosystem.
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
dependencies: [peft>=0.13.0, transformers>=4.45.0, torch>=2.0.0, bitsandbytes>=0.43.0]
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [Fine-Tuning, PEFT, LoRA, QLoRA, Parameter-Efficient, Adapters, Low-Rank, Memory Optimization, Multi-Adapter]
|
||||
|
||||
---
|
||||
|
||||
# PEFT (Parameter-Efficient Fine-Tuning)
|
||||
|
||||
Fine-tune LLMs by training <1% of parameters using LoRA, QLoRA, and 25+ adapter methods.
|
||||
|
||||
## When to use PEFT
|
||||
|
||||
**Use PEFT/LoRA when:**
|
||||
- Fine-tuning 7B-70B models on consumer GPUs (RTX 4090, A100)
|
||||
- Need to train <1% parameters (6MB adapters vs 14GB full model)
|
||||
- Want fast iteration with multiple task-specific adapters
|
||||
- Deploying multiple fine-tuned variants from one base model
|
||||
|
||||
**Use QLoRA (PEFT + quantization) when:**
|
||||
- Fine-tuning 70B models on single 24GB GPU
|
||||
- Memory is the primary constraint
|
||||
- Can accept ~5% quality trade-off vs full fine-tuning
|
||||
|
||||
**Use full fine-tuning instead when:**
|
||||
- Training small models (<1B parameters)
|
||||
- Need maximum quality and have compute budget
|
||||
- Significant domain shift requires updating all weights
|
||||
|
||||
## Quick start
|
||||
|
||||
### Installation
|
||||
|
||||
```bash
|
||||
# Basic installation
|
||||
pip install peft
|
||||
|
||||
# With quantization support (recommended)
|
||||
pip install peft bitsandbytes
|
||||
|
||||
# Full stack
|
||||
pip install peft transformers accelerate bitsandbytes datasets
|
||||
```
|
||||
|
||||
### LoRA fine-tuning (standard)
|
||||
|
||||
```python
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer
|
||||
from peft import get_peft_model, LoraConfig, TaskType
|
||||
from datasets import load_dataset
|
||||
|
||||
# Load base model
|
||||
model_name = "meta-llama/Llama-3.1-8B"
|
||||
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", device_map="auto")
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
# LoRA configuration
|
||||
lora_config = LoraConfig(
|
||||
task_type=TaskType.CAUSAL_LM,
|
||||
r=16, # Rank (8-64, higher = more capacity)
|
||||
lora_alpha=32, # Scaling factor (typically 2*r)
|
||||
lora_dropout=0.05, # Dropout for regularization
|
||||
target_modules=["q_proj", "v_proj", "k_proj", "o_proj"], # Attention layers
|
||||
bias="none" # Don't train biases
|
||||
)
|
||||
|
||||
# Apply LoRA
|
||||
model = get_peft_model(model, lora_config)
|
||||
model.print_trainable_parameters()
|
||||
# Output: trainable params: 13,631,488 || all params: 8,043,307,008 || trainable%: 0.17%
|
||||
|
||||
# Prepare dataset
|
||||
dataset = load_dataset("databricks/databricks-dolly-15k", split="train")
|
||||
|
||||
def tokenize(example):
|
||||
text = f"### Instruction:\n{example['instruction']}\n\n### Response:\n{example['response']}"
|
||||
return tokenizer(text, truncation=True, max_length=512, padding="max_length")
|
||||
|
||||
tokenized = dataset.map(tokenize, remove_columns=dataset.column_names)
|
||||
|
||||
# Training
|
||||
training_args = TrainingArguments(
|
||||
output_dir="./lora-llama",
|
||||
num_train_epochs=3,
|
||||
per_device_train_batch_size=4,
|
||||
gradient_accumulation_steps=4,
|
||||
learning_rate=2e-4,
|
||||
fp16=True,
|
||||
logging_steps=10,
|
||||
save_strategy="epoch"
|
||||
)
|
||||
|
||||
trainer = Trainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=tokenized,
|
||||
data_collator=lambda data: {"input_ids": torch.stack([f["input_ids"] for f in data]),
|
||||
"attention_mask": torch.stack([f["attention_mask"] for f in data]),
|
||||
"labels": torch.stack([f["input_ids"] for f in data])}
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
|
||||
# Save adapter only (6MB vs 16GB)
|
||||
model.save_pretrained("./lora-llama-adapter")
|
||||
```
|
||||
|
||||
### QLoRA fine-tuning (memory-efficient)
|
||||
|
||||
```python
|
||||
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
|
||||
from peft import get_peft_model, LoraConfig, prepare_model_for_kbit_training
|
||||
|
||||
# 4-bit quantization config
|
||||
bnb_config = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_quant_type="nf4", # NormalFloat4 (best for LLMs)
|
||||
bnb_4bit_compute_dtype="bfloat16", # Compute in bf16
|
||||
bnb_4bit_use_double_quant=True # Nested quantization
|
||||
)
|
||||
|
||||
# Load quantized model
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"meta-llama/Llama-3.1-70B",
|
||||
quantization_config=bnb_config,
|
||||
device_map="auto"
|
||||
)
|
||||
|
||||
# Prepare for training (enables gradient checkpointing)
|
||||
model = prepare_model_for_kbit_training(model)
|
||||
|
||||
# LoRA config for QLoRA
|
||||
lora_config = LoraConfig(
|
||||
r=64, # Higher rank for 70B
|
||||
lora_alpha=128,
|
||||
lora_dropout=0.1,
|
||||
target_modules=["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
|
||||
bias="none",
|
||||
task_type="CAUSAL_LM"
|
||||
)
|
||||
|
||||
model = get_peft_model(model, lora_config)
|
||||
# 70B model now fits on single 24GB GPU!
|
||||
```
|
||||
|
||||
## LoRA parameter selection
|
||||
|
||||
### Rank (r) - capacity vs efficiency
|
||||
|
||||
| Rank | Trainable Params | Memory | Quality | Use Case |
|
||||
|------|-----------------|--------|---------|----------|
|
||||
| 4 | ~3M | Minimal | Lower | Simple tasks, prototyping |
|
||||
| **8** | ~7M | Low | Good | **Recommended starting point** |
|
||||
| **16** | ~14M | Medium | Better | **General fine-tuning** |
|
||||
| 32 | ~27M | Higher | High | Complex tasks |
|
||||
| 64 | ~54M | High | Highest | Domain adaptation, 70B models |
|
||||
|
||||
### Alpha (lora_alpha) - scaling factor
|
||||
|
||||
```python
|
||||
# Rule of thumb: alpha = 2 * rank
|
||||
LoraConfig(r=16, lora_alpha=32) # Standard
|
||||
LoraConfig(r=16, lora_alpha=16) # Conservative (lower learning rate effect)
|
||||
LoraConfig(r=16, lora_alpha=64) # Aggressive (higher learning rate effect)
|
||||
```
|
||||
|
||||
### Target modules by architecture
|
||||
|
||||
```python
|
||||
# Llama / Mistral / Qwen
|
||||
target_modules = ["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
|
||||
|
||||
# GPT-2 / GPT-Neo
|
||||
target_modules = ["c_attn", "c_proj", "c_fc"]
|
||||
|
||||
# Falcon
|
||||
target_modules = ["query_key_value", "dense", "dense_h_to_4h", "dense_4h_to_h"]
|
||||
|
||||
# BLOOM
|
||||
target_modules = ["query_key_value", "dense", "dense_h_to_4h", "dense_4h_to_h"]
|
||||
|
||||
# Auto-detect all linear layers
|
||||
target_modules = "all-linear" # PEFT 0.6.0+
|
||||
```
|
||||
|
||||
## Loading and merging adapters
|
||||
|
||||
### Load trained adapter
|
||||
|
||||
```python
|
||||
from peft import PeftModel, AutoPeftModelForCausalLM
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
# Option 1: Load with PeftModel
|
||||
base_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B")
|
||||
model = PeftModel.from_pretrained(base_model, "./lora-llama-adapter")
|
||||
|
||||
# Option 2: Load directly (recommended)
|
||||
model = AutoPeftModelForCausalLM.from_pretrained(
|
||||
"./lora-llama-adapter",
|
||||
device_map="auto"
|
||||
)
|
||||
```
|
||||
|
||||
### Merge adapter into base model
|
||||
|
||||
```python
|
||||
# Merge for deployment (no adapter overhead)
|
||||
merged_model = model.merge_and_unload()
|
||||
|
||||
# Save merged model
|
||||
merged_model.save_pretrained("./llama-merged")
|
||||
tokenizer.save_pretrained("./llama-merged")
|
||||
|
||||
# Push to Hub
|
||||
merged_model.push_to_hub("username/llama-finetuned")
|
||||
```
|
||||
|
||||
### Multi-adapter serving
|
||||
|
||||
```python
|
||||
from peft import PeftModel
|
||||
|
||||
# Load base with first adapter
|
||||
model = AutoPeftModelForCausalLM.from_pretrained("./adapter-task1")
|
||||
|
||||
# Load additional adapters
|
||||
model.load_adapter("./adapter-task2", adapter_name="task2")
|
||||
model.load_adapter("./adapter-task3", adapter_name="task3")
|
||||
|
||||
# Switch between adapters at runtime
|
||||
model.set_adapter("task1") # Use task1 adapter
|
||||
output1 = model.generate(**inputs)
|
||||
|
||||
model.set_adapter("task2") # Switch to task2
|
||||
output2 = model.generate(**inputs)
|
||||
|
||||
# Disable adapters (use base model)
|
||||
with model.disable_adapter():
|
||||
base_output = model.generate(**inputs)
|
||||
```
|
||||
|
||||
## PEFT methods comparison
|
||||
|
||||
| Method | Trainable % | Memory | Speed | Best For |
|
||||
|--------|------------|--------|-------|----------|
|
||||
| **LoRA** | 0.1-1% | Low | Fast | General fine-tuning |
|
||||
| **QLoRA** | 0.1-1% | Very Low | Medium | Memory-constrained |
|
||||
| AdaLoRA | 0.1-1% | Low | Medium | Automatic rank selection |
|
||||
| IA3 | 0.01% | Minimal | Fastest | Few-shot adaptation |
|
||||
| Prefix Tuning | 0.1% | Low | Medium | Generation control |
|
||||
| Prompt Tuning | 0.001% | Minimal | Fast | Simple task adaptation |
|
||||
| P-Tuning v2 | 0.1% | Low | Medium | NLU tasks |
|
||||
|
||||
### IA3 (minimal parameters)
|
||||
|
||||
```python
|
||||
from peft import IA3Config
|
||||
|
||||
ia3_config = IA3Config(
|
||||
target_modules=["q_proj", "v_proj", "k_proj", "down_proj"],
|
||||
feedforward_modules=["down_proj"]
|
||||
)
|
||||
model = get_peft_model(model, ia3_config)
|
||||
# Trains only 0.01% of parameters!
|
||||
```
|
||||
|
||||
### Prefix Tuning
|
||||
|
||||
```python
|
||||
from peft import PrefixTuningConfig
|
||||
|
||||
prefix_config = PrefixTuningConfig(
|
||||
task_type="CAUSAL_LM",
|
||||
num_virtual_tokens=20, # Prepended tokens
|
||||
prefix_projection=True # Use MLP projection
|
||||
)
|
||||
model = get_peft_model(model, prefix_config)
|
||||
```
|
||||
|
||||
## Integration patterns
|
||||
|
||||
### With TRL (SFTTrainer)
|
||||
|
||||
```python
|
||||
from trl import SFTTrainer, SFTConfig
|
||||
from peft import LoraConfig
|
||||
|
||||
lora_config = LoraConfig(r=16, lora_alpha=32, target_modules="all-linear")
|
||||
|
||||
trainer = SFTTrainer(
|
||||
model=model,
|
||||
args=SFTConfig(output_dir="./output", max_seq_length=512),
|
||||
train_dataset=dataset,
|
||||
peft_config=lora_config, # Pass LoRA config directly
|
||||
)
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
### With Axolotl (YAML config)
|
||||
|
||||
```yaml
|
||||
# axolotl config.yaml
|
||||
adapter: lora
|
||||
lora_r: 16
|
||||
lora_alpha: 32
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules:
|
||||
- q_proj
|
||||
- v_proj
|
||||
- k_proj
|
||||
- o_proj
|
||||
lora_target_linear: true # Target all linear layers
|
||||
```
|
||||
|
||||
### With vLLM (inference)
|
||||
|
||||
```python
|
||||
from vllm import LLM
|
||||
from vllm.lora.request import LoRARequest
|
||||
|
||||
# Load base model with LoRA support
|
||||
llm = LLM(model="meta-llama/Llama-3.1-8B", enable_lora=True)
|
||||
|
||||
# Serve with adapter
|
||||
outputs = llm.generate(
|
||||
prompts,
|
||||
lora_request=LoRARequest("adapter1", 1, "./lora-adapter")
|
||||
)
|
||||
```
|
||||
|
||||
## Performance benchmarks
|
||||
|
||||
### Memory usage (Llama 3.1 8B)
|
||||
|
||||
| Method | GPU Memory | Trainable Params |
|
||||
|--------|-----------|------------------|
|
||||
| Full fine-tuning | 60+ GB | 8B (100%) |
|
||||
| LoRA r=16 | 18 GB | 14M (0.17%) |
|
||||
| QLoRA r=16 | 6 GB | 14M (0.17%) |
|
||||
| IA3 | 16 GB | 800K (0.01%) |
|
||||
|
||||
### Training speed (A100 80GB)
|
||||
|
||||
| Method | Tokens/sec | vs Full FT |
|
||||
|--------|-----------|------------|
|
||||
| Full FT | 2,500 | 1x |
|
||||
| LoRA | 3,200 | 1.3x |
|
||||
| QLoRA | 2,100 | 0.84x |
|
||||
|
||||
### Quality (MMLU benchmark)
|
||||
|
||||
| Model | Full FT | LoRA | QLoRA |
|
||||
|-------|---------|------|-------|
|
||||
| Llama 2-7B | 45.3 | 44.8 | 44.1 |
|
||||
| Llama 2-13B | 54.8 | 54.2 | 53.5 |
|
||||
|
||||
## Common issues
|
||||
|
||||
### CUDA OOM during training
|
||||
|
||||
```python
|
||||
# Solution 1: Enable gradient checkpointing
|
||||
model.gradient_checkpointing_enable()
|
||||
|
||||
# Solution 2: Reduce batch size + increase accumulation
|
||||
TrainingArguments(
|
||||
per_device_train_batch_size=1,
|
||||
gradient_accumulation_steps=16
|
||||
)
|
||||
|
||||
# Solution 3: Use QLoRA
|
||||
from transformers import BitsAndBytesConfig
|
||||
bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4")
|
||||
```
|
||||
|
||||
### Adapter not applying
|
||||
|
||||
```python
|
||||
# Verify adapter is active
|
||||
print(model.active_adapters) # Should show adapter name
|
||||
|
||||
# Check trainable parameters
|
||||
model.print_trainable_parameters()
|
||||
|
||||
# Ensure model in training mode
|
||||
model.train()
|
||||
```
|
||||
|
||||
### Quality degradation
|
||||
|
||||
```python
|
||||
# Increase rank
|
||||
LoraConfig(r=32, lora_alpha=64)
|
||||
|
||||
# Target more modules
|
||||
target_modules = "all-linear"
|
||||
|
||||
# Use more training data and epochs
|
||||
TrainingArguments(num_train_epochs=5)
|
||||
|
||||
# Lower learning rate
|
||||
TrainingArguments(learning_rate=1e-4)
|
||||
```
|
||||
|
||||
## Best practices
|
||||
|
||||
1. **Start with r=8-16**, increase if quality insufficient
|
||||
2. **Use alpha = 2 * rank** as starting point
|
||||
3. **Target attention + MLP layers** for best quality/efficiency
|
||||
4. **Enable gradient checkpointing** for memory savings
|
||||
5. **Save adapters frequently** (small files, easy rollback)
|
||||
6. **Evaluate on held-out data** before merging
|
||||
7. **Use QLoRA for 70B+ models** on consumer hardware
|
||||
|
||||
## References
|
||||
|
||||
- **[Advanced Usage](references/advanced-usage.md)** - DoRA, LoftQ, rank stabilization, custom modules
|
||||
- **[Troubleshooting](references/troubleshooting.md)** - Common errors, debugging, optimization
|
||||
|
||||
## Resources
|
||||
|
||||
- **GitHub**: https://github.com/huggingface/peft
|
||||
- **Docs**: https://huggingface.co/docs/peft
|
||||
- **LoRA Paper**: arXiv:2106.09685
|
||||
- **QLoRA Paper**: arXiv:2305.14314
|
||||
- **Models**: https://huggingface.co/models?library=peft
|
||||
|
|
@ -1,514 +0,0 @@
|
|||
# PEFT Advanced Usage Guide
|
||||
|
||||
## Advanced LoRA Variants
|
||||
|
||||
### DoRA (Weight-Decomposed Low-Rank Adaptation)
|
||||
|
||||
DoRA decomposes weights into magnitude and direction components, often achieving better results than standard LoRA:
|
||||
|
||||
```python
|
||||
from peft import LoraConfig
|
||||
|
||||
dora_config = LoraConfig(
|
||||
r=16,
|
||||
lora_alpha=32,
|
||||
target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
|
||||
use_dora=True, # Enable DoRA
|
||||
task_type="CAUSAL_LM"
|
||||
)
|
||||
|
||||
model = get_peft_model(model, dora_config)
|
||||
```
|
||||
|
||||
**When to use DoRA**:
|
||||
- Consistently outperforms LoRA on instruction-following tasks
|
||||
- Slightly higher memory (~10%) due to magnitude vectors
|
||||
- Best for quality-critical fine-tuning
|
||||
|
||||
### AdaLoRA (Adaptive Rank)
|
||||
|
||||
Automatically adjusts rank per layer based on importance:
|
||||
|
||||
```python
|
||||
from peft import AdaLoraConfig
|
||||
|
||||
adalora_config = AdaLoraConfig(
|
||||
init_r=64, # Initial rank
|
||||
target_r=16, # Target average rank
|
||||
tinit=200, # Warmup steps
|
||||
tfinal=1000, # Final pruning step
|
||||
deltaT=10, # Rank update frequency
|
||||
beta1=0.85,
|
||||
beta2=0.85,
|
||||
orth_reg_weight=0.5, # Orthogonality regularization
|
||||
target_modules=["q_proj", "v_proj"],
|
||||
task_type="CAUSAL_LM"
|
||||
)
|
||||
```
|
||||
|
||||
**Benefits**:
|
||||
- Allocates more rank to important layers
|
||||
- Can reduce total parameters while maintaining quality
|
||||
- Good for exploring optimal rank distribution
|
||||
|
||||
### LoRA+ (Asymmetric Learning Rates)
|
||||
|
||||
Different learning rates for A and B matrices:
|
||||
|
||||
```python
|
||||
from peft import LoraConfig
|
||||
|
||||
# LoRA+ uses higher LR for B matrix
|
||||
lora_plus_config = LoraConfig(
|
||||
r=16,
|
||||
lora_alpha=32,
|
||||
target_modules="all-linear",
|
||||
use_rslora=True, # Rank-stabilized LoRA (related technique)
|
||||
)
|
||||
|
||||
# Manual implementation of LoRA+
|
||||
from torch.optim import AdamW
|
||||
|
||||
# Group parameters
|
||||
lora_A_params = [p for n, p in model.named_parameters() if "lora_A" in n]
|
||||
lora_B_params = [p for n, p in model.named_parameters() if "lora_B" in n]
|
||||
|
||||
optimizer = AdamW([
|
||||
{"params": lora_A_params, "lr": 1e-4},
|
||||
{"params": lora_B_params, "lr": 1e-3}, # 10x higher for B
|
||||
])
|
||||
```
|
||||
|
||||
### rsLoRA (Rank-Stabilized LoRA)
|
||||
|
||||
Scales LoRA outputs to stabilize training with different ranks:
|
||||
|
||||
```python
|
||||
lora_config = LoraConfig(
|
||||
r=64,
|
||||
lora_alpha=64,
|
||||
use_rslora=True, # Enables rank-stabilized scaling
|
||||
target_modules="all-linear"
|
||||
)
|
||||
```
|
||||
|
||||
**When to use**:
|
||||
- When experimenting with different ranks
|
||||
- Helps maintain consistent behavior across rank values
|
||||
- Recommended for r > 32
|
||||
|
||||
## LoftQ (LoRA-Fine-Tuning-aware Quantization)
|
||||
|
||||
Initializes LoRA weights to compensate for quantization error:
|
||||
|
||||
```python
|
||||
from peft import LoftQConfig, LoraConfig, get_peft_model
|
||||
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
|
||||
|
||||
# LoftQ configuration
|
||||
loftq_config = LoftQConfig(
|
||||
loftq_bits=4, # Quantization bits
|
||||
loftq_iter=5, # Alternating optimization iterations
|
||||
)
|
||||
|
||||
# LoRA config with LoftQ initialization
|
||||
lora_config = LoraConfig(
|
||||
r=16,
|
||||
lora_alpha=32,
|
||||
target_modules="all-linear",
|
||||
init_lora_weights="loftq",
|
||||
loftq_config=loftq_config,
|
||||
task_type="CAUSAL_LM"
|
||||
)
|
||||
|
||||
# Load quantized model
|
||||
bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4")
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"meta-llama/Llama-3.1-8B",
|
||||
quantization_config=bnb_config
|
||||
)
|
||||
|
||||
model = get_peft_model(model, lora_config)
|
||||
```
|
||||
|
||||
**Benefits over standard QLoRA**:
|
||||
- Better initial quality after quantization
|
||||
- Faster convergence
|
||||
- ~1-2% better final accuracy on benchmarks
|
||||
|
||||
## Custom Module Targeting
|
||||
|
||||
### Target specific layers
|
||||
|
||||
```python
|
||||
# Target only first and last transformer layers
|
||||
lora_config = LoraConfig(
|
||||
r=16,
|
||||
lora_alpha=32,
|
||||
target_modules=["model.layers.0.self_attn.q_proj",
|
||||
"model.layers.0.self_attn.v_proj",
|
||||
"model.layers.31.self_attn.q_proj",
|
||||
"model.layers.31.self_attn.v_proj"],
|
||||
layers_to_transform=[0, 31] # Alternative approach
|
||||
)
|
||||
```
|
||||
|
||||
### Layer pattern matching
|
||||
|
||||
```python
|
||||
# Target layers 0-10 only
|
||||
lora_config = LoraConfig(
|
||||
r=16,
|
||||
lora_alpha=32,
|
||||
target_modules="all-linear",
|
||||
layers_to_transform=list(range(11)), # Layers 0-10
|
||||
layers_pattern="model.layers"
|
||||
)
|
||||
```
|
||||
|
||||
### Exclude specific layers
|
||||
|
||||
```python
|
||||
lora_config = LoraConfig(
|
||||
r=16,
|
||||
target_modules="all-linear",
|
||||
modules_to_save=["lm_head"], # Train these fully (not LoRA)
|
||||
)
|
||||
```
|
||||
|
||||
## Embedding and LM Head Training
|
||||
|
||||
### Train embeddings with LoRA
|
||||
|
||||
```python
|
||||
from peft import LoraConfig
|
||||
|
||||
# Include embeddings
|
||||
lora_config = LoraConfig(
|
||||
r=16,
|
||||
lora_alpha=32,
|
||||
target_modules=["q_proj", "v_proj", "embed_tokens"], # Include embeddings
|
||||
modules_to_save=["lm_head"], # Train lm_head fully
|
||||
)
|
||||
```
|
||||
|
||||
### Extending vocabulary with LoRA
|
||||
|
||||
```python
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from peft import get_peft_model, LoraConfig
|
||||
|
||||
# Add new tokens
|
||||
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B")
|
||||
new_tokens = ["<custom_token_1>", "<custom_token_2>"]
|
||||
tokenizer.add_tokens(new_tokens)
|
||||
|
||||
# Resize model embeddings
|
||||
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B")
|
||||
model.resize_token_embeddings(len(tokenizer))
|
||||
|
||||
# Configure LoRA to train new embeddings
|
||||
lora_config = LoraConfig(
|
||||
r=16,
|
||||
target_modules="all-linear",
|
||||
modules_to_save=["embed_tokens", "lm_head"], # Train these fully
|
||||
)
|
||||
|
||||
model = get_peft_model(model, lora_config)
|
||||
```
|
||||
|
||||
## Multi-Adapter Patterns
|
||||
|
||||
### Adapter composition
|
||||
|
||||
```python
|
||||
from peft import PeftModel
|
||||
|
||||
# Load model with multiple adapters
|
||||
model = AutoPeftModelForCausalLM.from_pretrained("./base-adapter")
|
||||
model.load_adapter("./style-adapter", adapter_name="style")
|
||||
model.load_adapter("./task-adapter", adapter_name="task")
|
||||
|
||||
# Combine adapters (weighted sum)
|
||||
model.add_weighted_adapter(
|
||||
adapters=["style", "task"],
|
||||
weights=[0.7, 0.3],
|
||||
adapter_name="combined",
|
||||
combination_type="linear" # or "cat", "svd"
|
||||
)
|
||||
|
||||
model.set_adapter("combined")
|
||||
```
|
||||
|
||||
### Adapter stacking
|
||||
|
||||
```python
|
||||
# Stack adapters (apply sequentially)
|
||||
model.add_weighted_adapter(
|
||||
adapters=["base", "domain", "task"],
|
||||
weights=[1.0, 1.0, 1.0],
|
||||
adapter_name="stacked",
|
||||
combination_type="cat" # Concatenate adapter outputs
|
||||
)
|
||||
```
|
||||
|
||||
### Dynamic adapter switching
|
||||
|
||||
```python
|
||||
import torch
|
||||
|
||||
class MultiAdapterModel:
|
||||
def __init__(self, base_model_path, adapter_paths):
|
||||
self.model = AutoPeftModelForCausalLM.from_pretrained(adapter_paths[0])
|
||||
for name, path in adapter_paths[1:].items():
|
||||
self.model.load_adapter(path, adapter_name=name)
|
||||
|
||||
def generate(self, prompt, adapter_name="default"):
|
||||
self.model.set_adapter(adapter_name)
|
||||
return self.model.generate(**self.tokenize(prompt))
|
||||
|
||||
def generate_ensemble(self, prompt, adapters, weights):
|
||||
"""Generate with weighted adapter ensemble"""
|
||||
outputs = []
|
||||
for adapter, weight in zip(adapters, weights):
|
||||
self.model.set_adapter(adapter)
|
||||
logits = self.model(**self.tokenize(prompt)).logits
|
||||
outputs.append(weight * logits)
|
||||
return torch.stack(outputs).sum(dim=0)
|
||||
```
|
||||
|
||||
## Memory Optimization
|
||||
|
||||
### Gradient checkpointing with LoRA
|
||||
|
||||
```python
|
||||
from peft import prepare_model_for_kbit_training
|
||||
|
||||
# Enable gradient checkpointing
|
||||
model = prepare_model_for_kbit_training(
|
||||
model,
|
||||
use_gradient_checkpointing=True,
|
||||
gradient_checkpointing_kwargs={"use_reentrant": False}
|
||||
)
|
||||
```
|
||||
|
||||
### CPU offloading for training
|
||||
|
||||
```python
|
||||
from accelerate import Accelerator
|
||||
|
||||
accelerator = Accelerator(
|
||||
mixed_precision="bf16",
|
||||
gradient_accumulation_steps=8,
|
||||
cpu_offload=True # Offload optimizer states to CPU
|
||||
)
|
||||
|
||||
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
|
||||
```
|
||||
|
||||
### Memory-efficient attention with LoRA
|
||||
|
||||
```python
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
# Combine Flash Attention 2 with LoRA
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"meta-llama/Llama-3.1-8B",
|
||||
attn_implementation="flash_attention_2",
|
||||
torch_dtype=torch.bfloat16
|
||||
)
|
||||
|
||||
# Apply LoRA
|
||||
model = get_peft_model(model, lora_config)
|
||||
```
|
||||
|
||||
## Inference Optimization
|
||||
|
||||
### Merge for deployment
|
||||
|
||||
```python
|
||||
# Merge adapter weights into base model
|
||||
merged_model = model.merge_and_unload()
|
||||
|
||||
# Quantize merged model for inference
|
||||
from transformers import BitsAndBytesConfig
|
||||
|
||||
bnb_config = BitsAndBytesConfig(load_in_4bit=True)
|
||||
quantized_model = AutoModelForCausalLM.from_pretrained(
|
||||
"./merged-model",
|
||||
quantization_config=bnb_config
|
||||
)
|
||||
```
|
||||
|
||||
### Export to different formats
|
||||
|
||||
```python
|
||||
# Export to GGUF (llama.cpp)
|
||||
# First merge, then convert
|
||||
merged_model.save_pretrained("./merged-model")
|
||||
|
||||
# Use llama.cpp converter
|
||||
# python convert-hf-to-gguf.py ./merged-model --outfile model.gguf
|
||||
|
||||
# Export to ONNX
|
||||
from optimum.onnxruntime import ORTModelForCausalLM
|
||||
|
||||
ort_model = ORTModelForCausalLM.from_pretrained(
|
||||
"./merged-model",
|
||||
export=True
|
||||
)
|
||||
ort_model.save_pretrained("./onnx-model")
|
||||
```
|
||||
|
||||
### Batch adapter inference
|
||||
|
||||
```python
|
||||
from vllm import LLM
|
||||
from vllm.lora.request import LoRARequest
|
||||
|
||||
# Initialize with LoRA support
|
||||
llm = LLM(
|
||||
model="meta-llama/Llama-3.1-8B",
|
||||
enable_lora=True,
|
||||
max_lora_rank=64,
|
||||
max_loras=4 # Max concurrent adapters
|
||||
)
|
||||
|
||||
# Batch with different adapters
|
||||
requests = [
|
||||
("prompt1", LoRARequest("adapter1", 1, "./adapter1")),
|
||||
("prompt2", LoRARequest("adapter2", 2, "./adapter2")),
|
||||
("prompt3", LoRARequest("adapter1", 1, "./adapter1")),
|
||||
]
|
||||
|
||||
outputs = llm.generate(
|
||||
[r[0] for r in requests],
|
||||
lora_request=[r[1] for r in requests]
|
||||
)
|
||||
```
|
||||
|
||||
## Training Recipes
|
||||
|
||||
### Instruction tuning recipe
|
||||
|
||||
```python
|
||||
lora_config = LoraConfig(
|
||||
r=16,
|
||||
lora_alpha=32,
|
||||
lora_dropout=0.05,
|
||||
target_modules="all-linear",
|
||||
bias="none",
|
||||
task_type="CAUSAL_LM"
|
||||
)
|
||||
|
||||
training_args = TrainingArguments(
|
||||
output_dir="./output",
|
||||
num_train_epochs=3,
|
||||
per_device_train_batch_size=4,
|
||||
gradient_accumulation_steps=4,
|
||||
learning_rate=2e-4,
|
||||
lr_scheduler_type="cosine",
|
||||
warmup_ratio=0.03,
|
||||
bf16=True,
|
||||
logging_steps=10,
|
||||
save_strategy="steps",
|
||||
save_steps=100,
|
||||
eval_strategy="steps",
|
||||
eval_steps=100,
|
||||
)
|
||||
```
|
||||
|
||||
### Code generation recipe
|
||||
|
||||
```python
|
||||
lora_config = LoraConfig(
|
||||
r=32, # Higher rank for code
|
||||
lora_alpha=64,
|
||||
lora_dropout=0.1,
|
||||
target_modules=["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
|
||||
bias="none",
|
||||
task_type="CAUSAL_LM"
|
||||
)
|
||||
|
||||
training_args = TrainingArguments(
|
||||
learning_rate=1e-4, # Lower LR for code
|
||||
num_train_epochs=2,
|
||||
max_seq_length=2048, # Longer sequences
|
||||
)
|
||||
```
|
||||
|
||||
### Conversational/Chat recipe
|
||||
|
||||
```python
|
||||
from trl import SFTTrainer
|
||||
|
||||
lora_config = LoraConfig(
|
||||
r=16,
|
||||
lora_alpha=16, # alpha = r for chat
|
||||
lora_dropout=0.05,
|
||||
target_modules="all-linear"
|
||||
)
|
||||
|
||||
# Use chat template
|
||||
def format_chat(example):
|
||||
messages = [
|
||||
{"role": "user", "content": example["instruction"]},
|
||||
{"role": "assistant", "content": example["response"]}
|
||||
]
|
||||
return tokenizer.apply_chat_template(messages, tokenize=False)
|
||||
|
||||
trainer = SFTTrainer(
|
||||
model=model,
|
||||
peft_config=lora_config,
|
||||
train_dataset=dataset.map(format_chat),
|
||||
max_seq_length=1024,
|
||||
)
|
||||
```
|
||||
|
||||
## Debugging and Validation
|
||||
|
||||
### Verify adapter application
|
||||
|
||||
```python
|
||||
# Check which modules have LoRA
|
||||
for name, module in model.named_modules():
|
||||
if hasattr(module, "lora_A"):
|
||||
print(f"LoRA applied to: {name}")
|
||||
|
||||
# Print detailed config
|
||||
print(model.peft_config)
|
||||
|
||||
# Check adapter state
|
||||
print(f"Active adapters: {model.active_adapters}")
|
||||
print(f"Trainable: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
|
||||
```
|
||||
|
||||
### Compare with base model
|
||||
|
||||
```python
|
||||
# Generate with adapter
|
||||
model.set_adapter("default")
|
||||
adapter_output = model.generate(**inputs)
|
||||
|
||||
# Generate without adapter
|
||||
with model.disable_adapter():
|
||||
base_output = model.generate(**inputs)
|
||||
|
||||
print(f"Adapter: {tokenizer.decode(adapter_output[0])}")
|
||||
print(f"Base: {tokenizer.decode(base_output[0])}")
|
||||
```
|
||||
|
||||
### Monitor training metrics
|
||||
|
||||
```python
|
||||
from transformers import TrainerCallback
|
||||
|
||||
class LoRACallback(TrainerCallback):
|
||||
def on_log(self, args, state, control, logs=None, **kwargs):
|
||||
if "loss" in logs:
|
||||
# Log adapter-specific metrics
|
||||
model = kwargs["model"]
|
||||
lora_params = sum(p.numel() for n, p in model.named_parameters()
|
||||
if "lora" in n and p.requires_grad)
|
||||
print(f"Step {state.global_step}: loss={logs['loss']:.4f}, lora_params={lora_params}")
|
||||
```
|
||||
|
|
@ -1,480 +0,0 @@
|
|||
# PEFT Troubleshooting Guide
|
||||
|
||||
## Installation Issues
|
||||
|
||||
### bitsandbytes CUDA Error
|
||||
|
||||
**Error**: `CUDA Setup failed despite GPU being available`
|
||||
|
||||
**Fix**:
|
||||
```bash
|
||||
# Check CUDA version
|
||||
nvcc --version
|
||||
|
||||
# Install matching bitsandbytes
|
||||
pip uninstall bitsandbytes
|
||||
pip install bitsandbytes --no-cache-dir
|
||||
|
||||
# Or compile from source for specific CUDA
|
||||
git clone https://github.com/TimDettmers/bitsandbytes.git
|
||||
cd bitsandbytes
|
||||
CUDA_VERSION=118 make cuda11x # Adjust for your CUDA
|
||||
pip install .
|
||||
```
|
||||
|
||||
### Triton Import Error
|
||||
|
||||
**Error**: `ModuleNotFoundError: No module named 'triton'`
|
||||
|
||||
**Fix**:
|
||||
```bash
|
||||
# Install triton (Linux only)
|
||||
pip install triton
|
||||
|
||||
# Windows: Triton not supported, use CUDA backend
|
||||
# Set environment variable to disable triton
|
||||
export CUDA_VISIBLE_DEVICES=0
|
||||
```
|
||||
|
||||
### PEFT Version Conflicts
|
||||
|
||||
**Error**: `AttributeError: 'LoraConfig' object has no attribute 'use_dora'`
|
||||
|
||||
**Fix**:
|
||||
```bash
|
||||
# Upgrade to latest PEFT
|
||||
pip install peft>=0.13.0 --upgrade
|
||||
|
||||
# Check version
|
||||
python -c "import peft; print(peft.__version__)"
|
||||
```
|
||||
|
||||
## Training Issues
|
||||
|
||||
### CUDA Out of Memory
|
||||
|
||||
**Error**: `torch.cuda.OutOfMemoryError: CUDA out of memory`
|
||||
|
||||
**Solutions**:
|
||||
|
||||
1. **Enable gradient checkpointing**:
|
||||
```python
|
||||
from peft import prepare_model_for_kbit_training
|
||||
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)
|
||||
```
|
||||
|
||||
2. **Reduce batch size**:
|
||||
```python
|
||||
TrainingArguments(
|
||||
per_device_train_batch_size=1,
|
||||
gradient_accumulation_steps=16 # Maintain effective batch size
|
||||
)
|
||||
```
|
||||
|
||||
3. **Use QLoRA**:
|
||||
```python
|
||||
from transformers import BitsAndBytesConfig
|
||||
|
||||
bnb_config = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_quant_type="nf4",
|
||||
bnb_4bit_use_double_quant=True
|
||||
)
|
||||
model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=bnb_config)
|
||||
```
|
||||
|
||||
4. **Lower LoRA rank**:
|
||||
```python
|
||||
LoraConfig(r=8) # Instead of r=16 or higher
|
||||
```
|
||||
|
||||
5. **Target fewer modules**:
|
||||
```python
|
||||
target_modules=["q_proj", "v_proj"] # Instead of all-linear
|
||||
```
|
||||
|
||||
### Loss Not Decreasing
|
||||
|
||||
**Problem**: Training loss stays flat or increases.
|
||||
|
||||
**Solutions**:
|
||||
|
||||
1. **Check learning rate**:
|
||||
```python
|
||||
# Start lower
|
||||
TrainingArguments(learning_rate=1e-4) # Not 2e-4 or higher
|
||||
```
|
||||
|
||||
2. **Verify adapter is active**:
|
||||
```python
|
||||
model.print_trainable_parameters()
|
||||
# Should show >0 trainable params
|
||||
|
||||
# Check adapter applied
|
||||
print(model.peft_config)
|
||||
```
|
||||
|
||||
3. **Check data formatting**:
|
||||
```python
|
||||
# Verify tokenization
|
||||
sample = dataset[0]
|
||||
decoded = tokenizer.decode(sample["input_ids"])
|
||||
print(decoded) # Should look correct
|
||||
```
|
||||
|
||||
4. **Increase rank**:
|
||||
```python
|
||||
LoraConfig(r=32, lora_alpha=64) # More capacity
|
||||
```
|
||||
|
||||
### NaN Loss
|
||||
|
||||
**Error**: `Loss is NaN`
|
||||
|
||||
**Fix**:
|
||||
```python
|
||||
# Use bf16 instead of fp16
|
||||
TrainingArguments(bf16=True, fp16=False)
|
||||
|
||||
# Or enable loss scaling
|
||||
TrainingArguments(fp16=True, fp16_full_eval=True)
|
||||
|
||||
# Lower learning rate
|
||||
TrainingArguments(learning_rate=5e-5)
|
||||
|
||||
# Check for data issues
|
||||
for batch in dataloader:
|
||||
if torch.isnan(batch["input_ids"].float()).any():
|
||||
print("NaN in input!")
|
||||
```
|
||||
|
||||
### Adapter Not Training
|
||||
|
||||
**Problem**: `trainable params: 0` or model not updating.
|
||||
|
||||
**Fix**:
|
||||
```python
|
||||
# Verify LoRA applied to correct modules
|
||||
for name, module in model.named_modules():
|
||||
if "lora" in name.lower():
|
||||
print(f"Found LoRA: {name}")
|
||||
|
||||
# Check target_modules match model architecture
|
||||
from peft.utils import TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING
|
||||
print(TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING.get(model.config.model_type))
|
||||
|
||||
# Ensure model in training mode
|
||||
model.train()
|
||||
|
||||
# Check requires_grad
|
||||
for name, param in model.named_parameters():
|
||||
if param.requires_grad:
|
||||
print(f"Trainable: {name}")
|
||||
```
|
||||
|
||||
## Loading Issues
|
||||
|
||||
### Adapter Loading Fails
|
||||
|
||||
**Error**: `ValueError: Can't find adapter weights`
|
||||
|
||||
**Fix**:
|
||||
```python
|
||||
# Check adapter files exist
|
||||
import os
|
||||
print(os.listdir("./adapter-path"))
|
||||
# Should contain: adapter_config.json, adapter_model.safetensors
|
||||
|
||||
# Load with correct structure
|
||||
from peft import PeftModel, PeftConfig
|
||||
|
||||
# Check config
|
||||
config = PeftConfig.from_pretrained("./adapter-path")
|
||||
print(config)
|
||||
|
||||
# Load base model first
|
||||
base_model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path)
|
||||
model = PeftModel.from_pretrained(base_model, "./adapter-path")
|
||||
```
|
||||
|
||||
### Base Model Mismatch
|
||||
|
||||
**Error**: `RuntimeError: size mismatch`
|
||||
|
||||
**Fix**:
|
||||
```python
|
||||
# Ensure base model matches adapter
|
||||
from peft import PeftConfig
|
||||
|
||||
config = PeftConfig.from_pretrained("./adapter-path")
|
||||
print(f"Base model: {config.base_model_name_or_path}")
|
||||
|
||||
# Load exact same base model
|
||||
base_model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path)
|
||||
```
|
||||
|
||||
### Safetensors vs PyTorch Format
|
||||
|
||||
**Error**: `ValueError: We couldn't connect to 'https://huggingface.co'`
|
||||
|
||||
**Fix**:
|
||||
```python
|
||||
# Force local loading
|
||||
model = PeftModel.from_pretrained(
|
||||
base_model,
|
||||
"./adapter-path",
|
||||
local_files_only=True
|
||||
)
|
||||
|
||||
# Or specify format
|
||||
model.save_pretrained("./adapter", safe_serialization=True) # safetensors
|
||||
model.save_pretrained("./adapter", safe_serialization=False) # pytorch
|
||||
```
|
||||
|
||||
## Inference Issues
|
||||
|
||||
### Slow Generation
|
||||
|
||||
**Problem**: Inference much slower than expected.
|
||||
|
||||
**Solutions**:
|
||||
|
||||
1. **Merge adapter for deployment**:
|
||||
```python
|
||||
merged_model = model.merge_and_unload()
|
||||
# No adapter overhead during inference
|
||||
```
|
||||
|
||||
2. **Use optimized inference engine**:
|
||||
```python
|
||||
from vllm import LLM
|
||||
llm = LLM(model="./merged-model", dtype="half")
|
||||
```
|
||||
|
||||
3. **Enable Flash Attention**:
|
||||
```python
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name,
|
||||
attn_implementation="flash_attention_2"
|
||||
)
|
||||
```
|
||||
|
||||
### Output Quality Issues
|
||||
|
||||
**Problem**: Fine-tuned model produces worse outputs.
|
||||
|
||||
**Solutions**:
|
||||
|
||||
1. **Check evaluation without adapter**:
|
||||
```python
|
||||
with model.disable_adapter():
|
||||
base_output = model.generate(**inputs)
|
||||
# Compare with adapter output
|
||||
```
|
||||
|
||||
2. **Lower temperature during eval**:
|
||||
```python
|
||||
model.generate(**inputs, temperature=0.1, do_sample=False)
|
||||
```
|
||||
|
||||
3. **Retrain with more data**:
|
||||
```python
|
||||
# Increase training samples
|
||||
# Use higher quality data
|
||||
# Train for more epochs
|
||||
```
|
||||
|
||||
### Wrong Adapter Active
|
||||
|
||||
**Problem**: Model using wrong adapter or no adapter.
|
||||
|
||||
**Fix**:
|
||||
```python
|
||||
# Check active adapters
|
||||
print(model.active_adapters)
|
||||
|
||||
# Explicitly set adapter
|
||||
model.set_adapter("your-adapter-name")
|
||||
|
||||
# List all adapters
|
||||
print(model.peft_config.keys())
|
||||
```
|
||||
|
||||
## QLoRA Specific Issues
|
||||
|
||||
### Quantization Errors
|
||||
|
||||
**Error**: `RuntimeError: mat1 and mat2 shapes cannot be multiplied`
|
||||
|
||||
**Fix**:
|
||||
```python
|
||||
# Ensure compute dtype matches
|
||||
bnb_config = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_compute_dtype=torch.bfloat16, # Match model dtype
|
||||
bnb_4bit_quant_type="nf4"
|
||||
)
|
||||
|
||||
# Load with correct dtype
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name,
|
||||
quantization_config=bnb_config,
|
||||
torch_dtype=torch.bfloat16
|
||||
)
|
||||
```
|
||||
|
||||
### QLoRA OOM
|
||||
|
||||
**Error**: OOM even with 4-bit quantization.
|
||||
|
||||
**Fix**:
|
||||
```python
|
||||
# Enable double quantization
|
||||
bnb_config = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_use_double_quant=True # Further memory reduction
|
||||
)
|
||||
|
||||
# Use offloading
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name,
|
||||
quantization_config=bnb_config,
|
||||
device_map="auto",
|
||||
max_memory={0: "20GB", "cpu": "100GB"}
|
||||
)
|
||||
```
|
||||
|
||||
### QLoRA Merge Fails
|
||||
|
||||
**Error**: `RuntimeError: expected scalar type BFloat16 but found Float`
|
||||
|
||||
**Fix**:
|
||||
```python
|
||||
# Dequantize before merging
|
||||
from peft import PeftModel
|
||||
|
||||
# Load in higher precision for merging
|
||||
base_model = AutoModelForCausalLM.from_pretrained(
|
||||
base_model_name,
|
||||
torch_dtype=torch.float16, # Not quantized
|
||||
device_map="auto"
|
||||
)
|
||||
|
||||
# Load adapter
|
||||
model = PeftModel.from_pretrained(base_model, "./qlora-adapter")
|
||||
|
||||
# Now merge
|
||||
merged = model.merge_and_unload()
|
||||
```
|
||||
|
||||
## Multi-Adapter Issues
|
||||
|
||||
### Adapter Conflict
|
||||
|
||||
**Error**: `ValueError: Adapter with name 'default' already exists`
|
||||
|
||||
**Fix**:
|
||||
```python
|
||||
# Use unique names
|
||||
model.load_adapter("./adapter1", adapter_name="task1")
|
||||
model.load_adapter("./adapter2", adapter_name="task2")
|
||||
|
||||
# Or delete existing
|
||||
model.delete_adapter("default")
|
||||
```
|
||||
|
||||
### Mixed Precision Adapters
|
||||
|
||||
**Error**: Adapters trained with different dtypes.
|
||||
|
||||
**Fix**:
|
||||
```python
|
||||
# Convert adapter precision
|
||||
model = PeftModel.from_pretrained(base_model, "./adapter")
|
||||
model = model.to(torch.bfloat16)
|
||||
|
||||
# Or load with specific dtype
|
||||
model = PeftModel.from_pretrained(
|
||||
base_model,
|
||||
"./adapter",
|
||||
torch_dtype=torch.bfloat16
|
||||
)
|
||||
```
|
||||
|
||||
## Performance Optimization
|
||||
|
||||
### Memory Profiling
|
||||
|
||||
```python
|
||||
import torch
|
||||
|
||||
def print_memory():
|
||||
if torch.cuda.is_available():
|
||||
allocated = torch.cuda.memory_allocated() / 1e9
|
||||
reserved = torch.cuda.memory_reserved() / 1e9
|
||||
print(f"Allocated: {allocated:.2f}GB, Reserved: {reserved:.2f}GB")
|
||||
|
||||
# Profile during training
|
||||
print_memory() # Before
|
||||
model.train()
|
||||
loss = model(**batch).loss
|
||||
loss.backward()
|
||||
print_memory() # After
|
||||
```
|
||||
|
||||
### Speed Profiling
|
||||
|
||||
```python
|
||||
import time
|
||||
import torch
|
||||
|
||||
def benchmark_generation(model, tokenizer, prompt, n_runs=5):
|
||||
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
|
||||
|
||||
# Warmup
|
||||
model.generate(**inputs, max_new_tokens=10)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# Benchmark
|
||||
times = []
|
||||
for _ in range(n_runs):
|
||||
start = time.perf_counter()
|
||||
outputs = model.generate(**inputs, max_new_tokens=100)
|
||||
torch.cuda.synchronize()
|
||||
times.append(time.perf_counter() - start)
|
||||
|
||||
tokens = outputs.shape[1] - inputs.input_ids.shape[1]
|
||||
avg_time = sum(times) / len(times)
|
||||
print(f"Speed: {tokens/avg_time:.2f} tokens/sec")
|
||||
|
||||
# Compare adapter vs merged
|
||||
benchmark_generation(adapter_model, tokenizer, "Hello")
|
||||
benchmark_generation(merged_model, tokenizer, "Hello")
|
||||
```
|
||||
|
||||
## Getting Help
|
||||
|
||||
1. **Check PEFT GitHub Issues**: https://github.com/huggingface/peft/issues
|
||||
2. **HuggingFace Forums**: https://discuss.huggingface.co/
|
||||
3. **PEFT Documentation**: https://huggingface.co/docs/peft
|
||||
|
||||
### Debugging Template
|
||||
|
||||
When reporting issues, include:
|
||||
|
||||
```python
|
||||
# System info
|
||||
import peft
|
||||
import transformers
|
||||
import torch
|
||||
|
||||
print(f"PEFT: {peft.__version__}")
|
||||
print(f"Transformers: {transformers.__version__}")
|
||||
print(f"PyTorch: {torch.__version__}")
|
||||
print(f"CUDA: {torch.version.cuda}")
|
||||
print(f"GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'N/A'}")
|
||||
|
||||
# Config
|
||||
print(model.peft_config)
|
||||
model.print_trainable_parameters()
|
||||
```
|
||||
|
|
@ -1,467 +0,0 @@
|
|||
---
|
||||
name: slime-rl-training
|
||||
description: Provides guidance for LLM post-training with RL using slime, a Megatron+SGLang framework. Use when training GLM models, implementing custom data generation workflows, or needing tight Megatron-LM integration for RL scaling.
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
dependencies: [sglang-router>=0.2.3, ray, torch>=2.0.0, transformers>=4.40.0]
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [Reinforcement Learning, Megatron-LM, SGLang, GRPO, Post-Training, GLM]
|
||||
|
||||
---
|
||||
|
||||
# slime: LLM Post-Training Framework for RL Scaling
|
||||
|
||||
slime is an LLM post-training framework from Tsinghua's THUDM team, powering GLM-4.5, GLM-4.6, and GLM-4.7. It connects Megatron-LM for training with SGLang for high-throughput rollout generation.
|
||||
|
||||
## When to Use slime
|
||||
|
||||
**Choose slime when you need:**
|
||||
- Megatron-LM native training with SGLang inference
|
||||
- Custom data generation workflows with flexible data buffers
|
||||
- Training GLM, Qwen3, DeepSeek V3, or Llama 3 models
|
||||
- Research-grade framework with production backing (Z.ai)
|
||||
|
||||
**Consider alternatives when:**
|
||||
- You need enterprise-grade stability features → use **miles**
|
||||
- You want flexible backend swapping → use **verl**
|
||||
- You need PyTorch-native abstractions → use **torchforge**
|
||||
|
||||
## Key Features
|
||||
|
||||
- **Training**: Megatron-LM with full parallelism support (TP, PP, DP, SP)
|
||||
- **Rollout**: SGLang-based high-throughput generation with router
|
||||
- **Data Buffer**: Flexible prompt management and sample storage
|
||||
- **Models**: GLM-4.x, Qwen3, DeepSeek V3/R1, Llama 3
|
||||
|
||||
## Architecture Overview
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────────────────────┐
|
||||
│ Data Buffer │
|
||||
│ - Prompt initialization and management │
|
||||
│ - Custom data generation and filtering │
|
||||
│ - Rollout sample storage │
|
||||
└─────────────┬───────────────────────────┬───────────────┘
|
||||
│ │
|
||||
┌─────────────▼───────────┐ ┌─────────────▼───────────────┐
|
||||
│ Training (Megatron-LM) │ │ Rollout (SGLang + Router) │
|
||||
│ - Actor model training │ │ - Response generation │
|
||||
│ - Critic (optional) │ │ - Reward/verifier output │
|
||||
│ - Weight sync to rollout│ │ - Multi-turn support │
|
||||
└─────────────────────────┘ └─────────────────────────────┘
|
||||
```
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
# Recommended: Docker
|
||||
docker pull slimerl/slime:latest
|
||||
docker run --rm --gpus all --ipc=host --shm-size=16g \
|
||||
-it slimerl/slime:latest /bin/bash
|
||||
|
||||
# Inside container
|
||||
cd /root/slime && pip install -e . --no-deps
|
||||
```
|
||||
|
||||
### From Source
|
||||
|
||||
```bash
|
||||
git clone https://github.com/THUDM/slime.git
|
||||
cd slime
|
||||
pip install -r requirements.txt
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
## Quick Start: GRPO Training
|
||||
|
||||
```bash
|
||||
# Source model configuration
|
||||
source scripts/models/qwen3-4B.sh
|
||||
|
||||
# Launch training
|
||||
python train.py \
|
||||
--actor-num-nodes 1 \
|
||||
--actor-num-gpus-per-node 4 \
|
||||
--rollout-num-gpus 4 \
|
||||
--advantage-estimator grpo \
|
||||
--use-kl-loss --kl-loss-coef 0.001 \
|
||||
--rollout-batch-size 32 \
|
||||
--n-samples-per-prompt 8 \
|
||||
--global-batch-size 256 \
|
||||
--num-rollout 3000 \
|
||||
--prompt-data /path/to/data.jsonl \
|
||||
${MODEL_ARGS[@]} ${CKPT_ARGS[@]}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Workflow 1: Standard GRPO Training
|
||||
|
||||
Use this workflow for training reasoning models with group-relative advantages.
|
||||
|
||||
### Prerequisites Checklist
|
||||
- [ ] Docker environment or Megatron-LM + SGLang installed
|
||||
- [ ] Model checkpoint (HuggingFace or Megatron format)
|
||||
- [ ] Training data in JSONL format
|
||||
|
||||
### Step 1: Prepare Data
|
||||
|
||||
```python
|
||||
# data.jsonl format
|
||||
{"prompt": "What is 2 + 2?", "label": "4"}
|
||||
{"prompt": "Solve: 3x = 12", "label": "x = 4"}
|
||||
```
|
||||
|
||||
Or with chat format:
|
||||
```python
|
||||
{
|
||||
"prompt": [
|
||||
{"role": "system", "content": "You are a math tutor."},
|
||||
{"role": "user", "content": "What is 15 + 27?"}
|
||||
],
|
||||
"label": "42"
|
||||
}
|
||||
```
|
||||
|
||||
### Step 2: Configure Model
|
||||
|
||||
Choose a pre-configured model script:
|
||||
|
||||
```bash
|
||||
# List available models
|
||||
ls scripts/models/
|
||||
# glm4-9B.sh, qwen3-4B.sh, qwen3-30B-A3B.sh, deepseek-v3.sh, llama3-8B.sh, ...
|
||||
|
||||
# Source your model
|
||||
source scripts/models/qwen3-4B.sh
|
||||
```
|
||||
|
||||
### Step 3: Launch Training
|
||||
|
||||
```bash
|
||||
python train.py \
|
||||
--actor-num-nodes 1 \
|
||||
--actor-num-gpus-per-node 8 \
|
||||
--rollout-num-gpus 8 \
|
||||
--advantage-estimator grpo \
|
||||
--use-kl-loss \
|
||||
--kl-loss-coef 0.001 \
|
||||
--prompt-data /path/to/train.jsonl \
|
||||
--input-key prompt \
|
||||
--label-key label \
|
||||
--apply-chat-template \
|
||||
--rollout-batch-size 32 \
|
||||
--n-samples-per-prompt 8 \
|
||||
--global-batch-size 256 \
|
||||
--num-rollout 3000 \
|
||||
--save-interval 100 \
|
||||
--eval-interval 50 \
|
||||
${MODEL_ARGS[@]}
|
||||
```
|
||||
|
||||
### Step 4: Monitor Training
|
||||
- [ ] Check TensorBoard: `tensorboard --logdir outputs/`
|
||||
- [ ] Verify reward curves are increasing
|
||||
- [ ] Monitor GPU utilization across nodes
|
||||
|
||||
---
|
||||
|
||||
## Workflow 2: Asynchronous Training
|
||||
|
||||
Use async mode for higher throughput by overlapping rollout and training.
|
||||
|
||||
### When to Use Async
|
||||
- Large models with long generation times
|
||||
- High GPU idle time in synchronous mode
|
||||
- Sufficient memory for buffering
|
||||
|
||||
### Launch Async Training
|
||||
|
||||
```bash
|
||||
python train_async.py \
|
||||
--actor-num-nodes 1 \
|
||||
--actor-num-gpus-per-node 8 \
|
||||
--rollout-num-gpus 8 \
|
||||
--advantage-estimator grpo \
|
||||
--async-buffer-size 4 \
|
||||
--prompt-data /path/to/train.jsonl \
|
||||
${MODEL_ARGS[@]}
|
||||
```
|
||||
|
||||
### Async-Specific Parameters
|
||||
|
||||
```bash
|
||||
--async-buffer-size 4 # Number of rollouts to buffer
|
||||
--update-weights-interval 2 # Sync weights every N rollouts
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Workflow 3: Multi-Turn Agentic Training
|
||||
|
||||
Use this workflow for training agents with tool use or multi-step reasoning.
|
||||
|
||||
### Prerequisites
|
||||
- [ ] Custom generate function for multi-turn logic
|
||||
- [ ] Tool/environment interface
|
||||
|
||||
### Step 1: Define Custom Generate Function
|
||||
|
||||
```python
|
||||
# custom_generate.py
|
||||
async def custom_generate(args, samples, evaluation=False):
|
||||
"""Multi-turn generation with tool calling."""
|
||||
for sample in samples:
|
||||
conversation = sample.prompt
|
||||
|
||||
for turn in range(args.max_turns):
|
||||
# Generate response
|
||||
response = await generate_single(conversation)
|
||||
|
||||
# Check for tool call
|
||||
tool_call = extract_tool_call(response)
|
||||
if tool_call:
|
||||
tool_result = execute_tool(tool_call)
|
||||
conversation.append({"role": "assistant", "content": response})
|
||||
conversation.append({"role": "tool", "content": tool_result})
|
||||
else:
|
||||
break
|
||||
|
||||
sample.response = response
|
||||
sample.reward = compute_reward(sample)
|
||||
|
||||
return samples
|
||||
```
|
||||
|
||||
### Step 2: Launch with Custom Function
|
||||
|
||||
```bash
|
||||
python train.py \
|
||||
--custom-generate-function-path custom_generate.py \
|
||||
--max-turns 5 \
|
||||
--prompt-data /path/to/agent_data.jsonl \
|
||||
${MODEL_ARGS[@]}
|
||||
```
|
||||
|
||||
See `examples/search-r1/` for a complete multi-turn search example.
|
||||
|
||||
---
|
||||
|
||||
## Configuration Reference
|
||||
|
||||
### Three Argument Categories
|
||||
|
||||
slime uses three types of arguments:
|
||||
|
||||
**1. Megatron Arguments** (passed directly):
|
||||
```bash
|
||||
--tensor-model-parallel-size 2
|
||||
--pipeline-model-parallel-size 1
|
||||
--num-layers 32
|
||||
--hidden-size 4096
|
||||
```
|
||||
|
||||
**2. SGLang Arguments** (prefixed with `--sglang-`):
|
||||
```bash
|
||||
--sglang-mem-fraction-static 0.8
|
||||
--sglang-context-length 8192
|
||||
--sglang-log-level INFO
|
||||
```
|
||||
|
||||
**3. slime Arguments**:
|
||||
```bash
|
||||
# Resource allocation
|
||||
--actor-num-nodes 1
|
||||
--actor-num-gpus-per-node 8
|
||||
--rollout-num-gpus 8
|
||||
--colocate # Share GPUs between training/inference
|
||||
|
||||
# Data
|
||||
--prompt-data /path/to/data.jsonl
|
||||
--input-key prompt
|
||||
--label-key label
|
||||
|
||||
# Training loop
|
||||
--num-rollout 3000
|
||||
--rollout-batch-size 32
|
||||
--n-samples-per-prompt 8
|
||||
--global-batch-size 256
|
||||
|
||||
# Algorithm
|
||||
--advantage-estimator grpo # or: gspo, ppo, reinforce_plus_plus
|
||||
--use-kl-loss
|
||||
--kl-loss-coef 0.001
|
||||
```
|
||||
|
||||
### Key Constraints
|
||||
|
||||
```
|
||||
rollout_batch_size × n_samples_per_prompt = global_batch_size × num_steps_per_rollout
|
||||
```
|
||||
|
||||
Example: 32 × 8 = 256 × 1
|
||||
|
||||
---
|
||||
|
||||
## Data Buffer System
|
||||
|
||||
slime's data buffer enables flexible data management:
|
||||
|
||||
### Basic Data Source
|
||||
|
||||
```python
|
||||
class RolloutDataSource:
|
||||
def get_samples(self, num_samples):
|
||||
"""Fetch prompts from dataset."""
|
||||
return self.dataset.sample(num_samples)
|
||||
|
||||
def add_samples(self, samples):
|
||||
"""Called after generation (no-op by default)."""
|
||||
pass
|
||||
```
|
||||
|
||||
### Buffered Data Source (Off-Policy)
|
||||
|
||||
```python
|
||||
class RolloutDataSourceWithBuffer(RolloutDataSource):
|
||||
def __init__(self):
|
||||
self.buffer = []
|
||||
|
||||
def add_samples(self, samples):
|
||||
"""Store generated samples for reuse."""
|
||||
self.buffer.extend(samples)
|
||||
|
||||
def buffer_filter(self, args, buffer, num_samples):
|
||||
"""Custom selection logic (prioritized, stratified, etc.)."""
|
||||
return select_best(buffer, num_samples)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Common Issues and Solutions
|
||||
|
||||
### Issue: SGLang Engine Crash
|
||||
|
||||
**Symptoms**: Inference engine dies mid-training
|
||||
|
||||
**Solutions**:
|
||||
```bash
|
||||
# Enable fault tolerance
|
||||
--use-fault-tolerance
|
||||
|
||||
# Increase memory allocation
|
||||
--sglang-mem-fraction-static 0.85
|
||||
|
||||
# Reduce batch size
|
||||
--rollout-batch-size 16
|
||||
```
|
||||
|
||||
### Issue: Weight Sync Timeout
|
||||
|
||||
**Symptoms**: Training hangs after rollout
|
||||
|
||||
**Solutions**:
|
||||
```bash
|
||||
# Increase sync interval
|
||||
--update-weights-interval 5
|
||||
|
||||
# Use colocated mode (no network transfer)
|
||||
--colocate
|
||||
```
|
||||
|
||||
### Issue: OOM During Training
|
||||
|
||||
**Symptoms**: CUDA OOM in backward pass
|
||||
|
||||
**Solutions**:
|
||||
```bash
|
||||
# Enable gradient checkpointing
|
||||
--recompute-activations
|
||||
|
||||
# Reduce micro-batch size
|
||||
--micro-batch-size 1
|
||||
|
||||
# Enable sequence parallelism
|
||||
--sequence-parallel
|
||||
```
|
||||
|
||||
### Issue: Slow Data Loading
|
||||
|
||||
**Symptoms**: GPU idle during data fetch
|
||||
|
||||
**Solutions**:
|
||||
```bash
|
||||
# Increase data workers
|
||||
--num-data-workers 4
|
||||
|
||||
# Use streaming dataset
|
||||
--streaming-data
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Supported Models
|
||||
|
||||
| Model Family | Configurations |
|
||||
|--------------|----------------|
|
||||
| GLM | GLM-4.5, GLM-4.6, GLM-4.7, GLM-Z1-9B |
|
||||
| Qwen | Qwen3 (4B, 8B, 30B-A3B), Qwen3-MoE, Qwen2.5 |
|
||||
| DeepSeek | V3, V3.1, R1 |
|
||||
| Llama | Llama 3 (8B, 70B) |
|
||||
| Others | Kimi K2, Moonlight-16B |
|
||||
|
||||
Each model has pre-configured scripts in `scripts/models/`.
|
||||
|
||||
---
|
||||
|
||||
## Advanced Topics
|
||||
|
||||
### Co-location Mode
|
||||
|
||||
Share GPUs between training and inference to reduce memory:
|
||||
|
||||
```bash
|
||||
python train.py \
|
||||
--colocate \
|
||||
--actor-num-gpus-per-node 8 \
|
||||
--sglang-mem-fraction-static 0.4 \
|
||||
${MODEL_ARGS[@]}
|
||||
```
|
||||
|
||||
### Custom Reward Model
|
||||
|
||||
```python
|
||||
# custom_rm.py
|
||||
class CustomRewardModel:
|
||||
def __init__(self, model_path):
|
||||
self.model = load_model(model_path)
|
||||
|
||||
def compute_reward(self, prompts, responses):
|
||||
inputs = self.tokenize(prompts, responses)
|
||||
scores = self.model(inputs)
|
||||
return scores.tolist()
|
||||
```
|
||||
|
||||
```bash
|
||||
--custom-rm-path custom_rm.py
|
||||
```
|
||||
|
||||
### Evaluation Multi-Task
|
||||
|
||||
```bash
|
||||
--eval-prompt-data aime /path/to/aime.jsonl \
|
||||
--eval-prompt-data gsm8k /path/to/gsm8k.jsonl \
|
||||
--n-samples-per-eval-prompt 16
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Resources
|
||||
|
||||
- **Documentation**: https://thudm.github.io/slime/
|
||||
- **GitHub**: https://github.com/THUDM/slime
|
||||
- **Blog**: https://lmsys.org/blog/2025-07-09-slime/
|
||||
- **Examples**: See `examples/` directory for 14+ worked examples
|
||||
|
||||
|
|
@ -1,392 +0,0 @@
|
|||
# slime API Reference
|
||||
|
||||
## Architecture Overview
|
||||
|
||||
slime operates with a three-module architecture orchestrated by Ray:
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────────────────────┐
|
||||
│ Data Buffer │
|
||||
│ - Prompt initialization and management │
|
||||
│ - Custom data generation and filtering │
|
||||
│ - Rollout sample storage │
|
||||
└─────────────┬───────────────────────────┬───────────────┘
|
||||
│ │
|
||||
┌─────────────▼───────────┐ ┌─────────────▼───────────────┐
|
||||
│ Training (Megatron-LM) │ │ Rollout (SGLang + Router) │
|
||||
│ - Actor model training │ │ - Response generation │
|
||||
│ - Critic (optional) │ │ - Reward/verifier output │
|
||||
│ - Weight sync to rollout│ │ - Multi-turn support │
|
||||
└─────────────────────────┘ └─────────────────────────────┘
|
||||
```
|
||||
|
||||
## Core Data Structures
|
||||
|
||||
### Sample Object
|
||||
|
||||
The `Sample` object is the core data structure defined in `slime/utils/types.py`:
|
||||
|
||||
```python
|
||||
from slime.utils.types import Sample
|
||||
|
||||
@dataclass
|
||||
class Sample:
|
||||
# Core fields
|
||||
group_index: Optional[int] # Group index for batching
|
||||
index: Optional[int] # Sample index
|
||||
prompt: str | list[dict] = "" # Input prompt or chat history
|
||||
tokens: list[int] = field(default_factory=list) # Token IDs
|
||||
response: str = "" # Generated response
|
||||
response_length: int = 0 # Response length in tokens
|
||||
label: Optional[str] = None # Ground truth label
|
||||
reward: Optional[float | dict] = None # RL reward signal
|
||||
loss_mask: Optional[list[int]] = None # 1=compute loss, 0=mask
|
||||
status: Status = Status.PENDING # Sample status
|
||||
metadata: dict = field(default_factory=dict) # Custom data
|
||||
|
||||
# Multimodal support
|
||||
multimodal_inputs: Optional[Any] = None # Raw multimodal data (images, videos)
|
||||
multimodal_train_inputs: Optional[Any] = None # Processed multimodal data (pixel_values)
|
||||
|
||||
# Rollout tracking
|
||||
weight_versions: list[str] = field(default_factory=list)
|
||||
rollout_log_probs: Optional[list[float]] = None # Log probs from SGLang
|
||||
rollout_routed_experts: Optional[list[list[int]]] = None # Expert routing (MoE)
|
||||
|
||||
# Control fields
|
||||
remove_sample: bool = False
|
||||
generate_function_path: Optional[str] = None
|
||||
train_metadata: Optional[dict] = None
|
||||
non_generation_time: float = 0.0
|
||||
|
||||
# Speculative decoding info (nested dataclass)
|
||||
@dataclass
|
||||
class SpecInfo:
|
||||
spec_accept_token_num: int = 0
|
||||
spec_draft_token_num: int = 0
|
||||
spec_verify_ct: int = 0
|
||||
completion_token_num: int = 0
|
||||
```
|
||||
|
||||
### Status Enum
|
||||
|
||||
```python
|
||||
class Status(Enum):
|
||||
PENDING = "pending" # Not yet processed
|
||||
COMPLETED = "completed" # Successfully generated
|
||||
TRUNCATED = "truncated" # Hit max length
|
||||
ABORTED = "aborted" # Failed generation
|
||||
FAILED = "failed" # Generation failed
|
||||
```
|
||||
|
||||
## Configuration System
|
||||
|
||||
slime uses three categories of command-line arguments:
|
||||
|
||||
### 1. Megatron Arguments
|
||||
|
||||
All Megatron-LM arguments are supported directly:
|
||||
|
||||
```bash
|
||||
--tensor-model-parallel-size 2
|
||||
--pipeline-model-parallel-size 1
|
||||
--num-layers 32
|
||||
--hidden-size 4096
|
||||
--num-attention-heads 32
|
||||
--seq-length 4096
|
||||
--micro-batch-size 1
|
||||
--global-batch-size 256
|
||||
```
|
||||
|
||||
### 2. SGLang Arguments
|
||||
|
||||
SGLang arguments are prefixed with `--sglang-`:
|
||||
|
||||
```bash
|
||||
--sglang-mem-fraction-static 0.8 # GPU memory for KV cache
|
||||
--sglang-context-length 8192 # Maximum context length
|
||||
--sglang-log-level INFO # Logging verbosity
|
||||
--sglang-tp-size 2 # Tensor parallelism
|
||||
--sglang-disable-cuda-graph # Disable CUDA graphs
|
||||
```
|
||||
|
||||
### 3. slime-Specific Arguments
|
||||
|
||||
Defined in `slime/utils/arguments.py`:
|
||||
|
||||
```bash
|
||||
# Resource Allocation
|
||||
--actor-num-nodes 1 # Training nodes
|
||||
--actor-num-gpus-per-node 8 # GPUs per training node
|
||||
--rollout-num-gpus 8 # Total rollout GPUs
|
||||
--rollout-num-gpus-per-engine 2 # GPUs per SGLang engine
|
||||
--colocate # Share GPUs for train/inference
|
||||
|
||||
# Data Configuration
|
||||
--prompt-data /path/to/data.jsonl # Training data path
|
||||
--input-key prompt # Key for prompts in JSON
|
||||
--label-key label # Key for labels in JSON
|
||||
--apply-chat-template # Apply chat formatting
|
||||
|
||||
# Training Loop
|
||||
--num-rollout 3000 # Total rollout iterations
|
||||
--rollout-batch-size 32 # Prompts per rollout
|
||||
--n-samples-per-prompt 8 # Responses per prompt
|
||||
--global-batch-size 256 # Training batch size
|
||||
--num-steps-per-rollout 1 # Training steps per rollout
|
||||
|
||||
# RL Algorithm
|
||||
--advantage-estimator grpo # grpo, gspo, ppo, reinforce_plus_plus
|
||||
--use-kl-loss # Enable KL loss
|
||||
--kl-loss-coef 0.001 # KL coefficient
|
||||
--calculate-per-token-loss # Token-level loss
|
||||
|
||||
# Off-Policy Options
|
||||
--use-tis # Truncated Importance Sampling
|
||||
--tis-threshold 0.9 # TIS threshold
|
||||
--true-on-policy-mode # Force on-policy training
|
||||
```
|
||||
|
||||
## Data Buffer System
|
||||
|
||||
### RolloutDataSource (Base Class)
|
||||
|
||||
```python
|
||||
from slime.data import RolloutDataSource
|
||||
|
||||
class RolloutDataSource:
|
||||
def __init__(self, dataset, args):
|
||||
self.dataset = dataset
|
||||
self.args = args
|
||||
|
||||
def get_samples(self, num_samples: int) -> list[Sample]:
|
||||
"""Fetch prompts from dataset."""
|
||||
return [Sample(prompt=p) for p in self.dataset.sample(num_samples)]
|
||||
|
||||
def add_samples(self, samples: list[Sample]) -> None:
|
||||
"""Called after generation (no-op by default)."""
|
||||
pass
|
||||
```
|
||||
|
||||
### Buffered Data Source (Off-Policy)
|
||||
|
||||
```python
|
||||
from slime.data import RolloutDataSourceWithBuffer
|
||||
|
||||
class RolloutDataSourceWithBuffer(RolloutDataSource):
|
||||
def __init__(self, dataset, args):
|
||||
super().__init__(dataset, args)
|
||||
self.buffer = []
|
||||
|
||||
def add_samples(self, samples: list[Sample]) -> None:
|
||||
"""Store generated samples for reuse."""
|
||||
self.buffer.extend(samples)
|
||||
|
||||
def buffer_filter(self, args, buffer, num_samples) -> list[Sample]:
|
||||
"""Custom selection logic."""
|
||||
# Example: prioritized sampling based on reward
|
||||
sorted_buffer = sorted(buffer, key=lambda s: s.reward, reverse=True)
|
||||
return sorted_buffer[:num_samples]
|
||||
```
|
||||
|
||||
## Custom Functions
|
||||
|
||||
### Custom Generate Function
|
||||
|
||||
For multi-turn or tool-calling scenarios:
|
||||
|
||||
```python
|
||||
# custom_generate.py
|
||||
from slime.data import Sample
|
||||
|
||||
async def custom_generate(args, samples: list[Sample], evaluation: bool = False) -> list[Sample]:
|
||||
"""
|
||||
Custom generation function for multi-turn interactions.
|
||||
|
||||
Args:
|
||||
args: Training arguments
|
||||
samples: List of Sample objects with prompts
|
||||
evaluation: Whether this is an evaluation run
|
||||
|
||||
Returns:
|
||||
List of Sample objects with responses and rewards
|
||||
"""
|
||||
for sample in samples:
|
||||
conversation = sample.prompt if isinstance(sample.prompt, list) else [
|
||||
{"role": "user", "content": sample.prompt}
|
||||
]
|
||||
|
||||
for turn in range(args.max_turns):
|
||||
# Generate response
|
||||
response = await generate_single(conversation)
|
||||
|
||||
# Check for tool call
|
||||
tool_call = extract_tool_call(response)
|
||||
if tool_call:
|
||||
# Execute tool
|
||||
tool_result = await execute_tool(tool_call)
|
||||
conversation.append({"role": "assistant", "content": response})
|
||||
conversation.append({"role": "tool", "content": tool_result})
|
||||
else:
|
||||
# Final response
|
||||
sample.response = response
|
||||
break
|
||||
|
||||
# Compute reward
|
||||
sample.reward = compute_reward(sample)
|
||||
|
||||
# Set loss mask (1 for model tokens, 0 for tool responses)
|
||||
sample.loss_mask = build_loss_mask(sample)
|
||||
|
||||
return samples
|
||||
```
|
||||
|
||||
Usage:
|
||||
```bash
|
||||
python train.py \
|
||||
--custom-generate-function-path custom_generate.py \
|
||||
--max-turns 5
|
||||
```
|
||||
|
||||
### Custom Reward Function
|
||||
|
||||
```python
|
||||
# custom_rm.py
|
||||
from slime.data import Sample
|
||||
|
||||
async def reward_func(args, sample: Sample, **kwargs) -> float:
|
||||
"""
|
||||
Compute reward for a single sample.
|
||||
|
||||
Args:
|
||||
args: Training arguments
|
||||
sample: Sample object with response
|
||||
|
||||
Returns:
|
||||
Reward score (float)
|
||||
"""
|
||||
response = sample.response
|
||||
ground_truth = sample.label or sample.metadata.get("answer", "")
|
||||
|
||||
# Example: exact match reward
|
||||
if response.strip() == ground_truth.strip():
|
||||
return 1.0
|
||||
return 0.0
|
||||
|
||||
# For batched processing (more efficient)
|
||||
async def batched_custom_rm(args, samples: list[Sample]) -> list[float]:
|
||||
"""Batch reward computation."""
|
||||
rewards = []
|
||||
for sample in samples:
|
||||
reward = await reward_func(args, sample)
|
||||
rewards.append(reward)
|
||||
return rewards
|
||||
```
|
||||
|
||||
Usage:
|
||||
```bash
|
||||
python train.py \
|
||||
--custom-rm-path custom_rm.py \
|
||||
--group-rm # Enable batched processing
|
||||
```
|
||||
|
||||
## Model Configuration
|
||||
|
||||
### Pre-configured Model Scripts
|
||||
|
||||
Located in `scripts/models/`:
|
||||
|
||||
```bash
|
||||
# List available models
|
||||
ls scripts/models/
|
||||
# glm4-9B.sh, qwen3-4B.sh, qwen3-30B-A3B.sh, deepseek-v3.sh, llama3-8B.sh
|
||||
|
||||
# Source model configuration
|
||||
source scripts/models/qwen3-4B.sh
|
||||
# This sets MODEL_ARGS and CKPT_ARGS arrays
|
||||
```
|
||||
|
||||
### Example Model Script
|
||||
|
||||
```bash
|
||||
# scripts/models/qwen3-4B.sh
|
||||
export MODEL_ARGS=(
|
||||
--num-layers 36
|
||||
--hidden-size 2560
|
||||
--num-attention-heads 20
|
||||
--num-query-groups 4
|
||||
--ffn-hidden-size 6912
|
||||
--max-position-embeddings 32768
|
||||
--rotary-percent 1.0
|
||||
--rotary-base 1000000
|
||||
--swiglu
|
||||
--untie-embeddings-and-output-weights
|
||||
--no-position-embedding
|
||||
--normalization RMSNorm
|
||||
--tokenizer-type HuggingFaceTokenizer
|
||||
--bf16
|
||||
)
|
||||
|
||||
export CKPT_ARGS=(
|
||||
--hf-checkpoint /path/to/qwen3-4b-hf
|
||||
--initial-megatron-checkpoint /path/to/megatron/ckpt
|
||||
)
|
||||
```
|
||||
|
||||
## Async Training
|
||||
|
||||
### Enabling Async Mode
|
||||
|
||||
```bash
|
||||
python train_async.py \
|
||||
--actor-num-gpus-per-node 8 \
|
||||
--rollout-num-gpus 8 \
|
||||
--async-buffer-size 4 \
|
||||
--update-weights-interval 2 \
|
||||
${MODEL_ARGS[@]}
|
||||
```
|
||||
|
||||
### Async-Specific Parameters
|
||||
|
||||
```bash
|
||||
--async-buffer-size 4 # Number of rollouts to buffer
|
||||
--update-weights-interval 2 # Sync weights every N rollouts
|
||||
```
|
||||
|
||||
**Note**: Colocated mode (`--colocate`) is NOT supported with async training.
|
||||
|
||||
## Evaluation
|
||||
|
||||
### Multi-Task Evaluation
|
||||
|
||||
```bash
|
||||
--eval-prompt-data aime /path/to/aime.jsonl \
|
||||
--eval-prompt-data gsm8k /path/to/gsm8k.jsonl \
|
||||
--n-samples-per-eval-prompt 16 \
|
||||
--eval-interval 50
|
||||
```
|
||||
|
||||
### Evaluation Configuration
|
||||
|
||||
```bash
|
||||
--eval-interval 50 # Evaluate every N rollouts
|
||||
--n-samples-per-eval-prompt 16 # Samples for evaluation
|
||||
--eval-temperature 0.0 # Greedy decoding for eval
|
||||
```
|
||||
|
||||
## Supported Models
|
||||
|
||||
| Model Family | Configurations |
|
||||
|--------------|----------------|
|
||||
| GLM | GLM-4.5, GLM-4.6, GLM-4.7, GLM-Z1-9B |
|
||||
| Qwen | Qwen3 (4B, 8B, 30B-A3B), Qwen3-MoE, Qwen2.5 |
|
||||
| DeepSeek | V3, V3.1, R1 |
|
||||
| Llama | Llama 3 (8B, 70B) |
|
||||
| Others | Kimi K2, Moonlight-16B |
|
||||
|
||||
## Resources
|
||||
|
||||
- Documentation: https://thudm.github.io/slime/
|
||||
- GitHub: https://github.com/THUDM/slime
|
||||
- Blog: https://lmsys.org/blog/2025-07-09-slime/
|
||||
- Examples: `examples/` directory (14+ worked examples)
|
||||
|
|
@ -1,386 +0,0 @@
|
|||
# slime Troubleshooting Guide
|
||||
|
||||
## Common Issues and Solutions
|
||||
|
||||
### SGLang Issues
|
||||
|
||||
#### Issue: SGLang Engine Crash
|
||||
|
||||
**Symptoms**: Inference engine dies mid-training, connection errors
|
||||
|
||||
**Solutions**:
|
||||
|
||||
1. **Enable fault tolerance**:
|
||||
```bash
|
||||
--use-fault-tolerance
|
||||
```
|
||||
|
||||
2. **Increase memory allocation**:
|
||||
```bash
|
||||
--sglang-mem-fraction-static 0.85 # Increase from 0.8
|
||||
```
|
||||
|
||||
3. **Reduce batch size**:
|
||||
```bash
|
||||
--rollout-batch-size 16 # Reduce from 32
|
||||
```
|
||||
|
||||
4. **Disable CUDA graphs** (for debugging):
|
||||
```bash
|
||||
--sglang-disable-cuda-graph
|
||||
```
|
||||
|
||||
#### Issue: SGLang Router Load Imbalance
|
||||
|
||||
**Symptoms**: Some SGLang engines overloaded while others idle
|
||||
|
||||
**Solutions**:
|
||||
|
||||
1. **Adjust routing strategy**:
|
||||
```bash
|
||||
--sglang-router-strategy round_robin
|
||||
```
|
||||
|
||||
2. **Increase number of engines**:
|
||||
```bash
|
||||
--rollout-num-gpus-per-engine 1 # More engines, less GPUs each
|
||||
```
|
||||
|
||||
### Weight Synchronization Issues
|
||||
|
||||
#### Issue: Weight Sync Timeout
|
||||
|
||||
**Symptoms**: Training hangs after rollout, timeout errors
|
||||
|
||||
**Solutions**:
|
||||
|
||||
1. **Increase sync interval** (async mode):
|
||||
```bash
|
||||
--update-weights-interval 5 # Increase from 2
|
||||
```
|
||||
|
||||
2. **Use colocated mode** (eliminates network transfer):
|
||||
```bash
|
||||
--colocate
|
||||
```
|
||||
|
||||
3. **Check network bandwidth**:
|
||||
```bash
|
||||
# Verify InfiniBand is enabled
|
||||
ibstat
|
||||
```
|
||||
|
||||
#### Issue: Weight Sync Failures in Multi-Node
|
||||
|
||||
**Symptoms**: Nodes fail to receive updated weights
|
||||
|
||||
**Solutions**:
|
||||
|
||||
1. **Set NCCL environment**:
|
||||
```bash
|
||||
export NCCL_DEBUG=INFO
|
||||
export NCCL_SOCKET_IFNAME=eth0
|
||||
export NCCL_IB_DISABLE=0
|
||||
```
|
||||
|
||||
2. **Increase timeout**:
|
||||
```bash
|
||||
export NCCL_TIMEOUT=1800
|
||||
```
|
||||
|
||||
### Memory Issues
|
||||
|
||||
#### Issue: OOM During Training
|
||||
|
||||
**Symptoms**: CUDA OOM in backward pass
|
||||
|
||||
**Solutions**:
|
||||
|
||||
1. **Enable gradient checkpointing**:
|
||||
```bash
|
||||
--recompute-activations
|
||||
```
|
||||
|
||||
2. **Reduce micro-batch size**:
|
||||
```bash
|
||||
--micro-batch-size 1
|
||||
```
|
||||
|
||||
3. **Enable sequence parallelism**:
|
||||
```bash
|
||||
--sequence-parallel
|
||||
```
|
||||
|
||||
4. **Reduce global batch size**:
|
||||
```bash
|
||||
--global-batch-size 128 # Reduce from 256
|
||||
```
|
||||
|
||||
#### Issue: OOM in Colocated Mode
|
||||
|
||||
**Symptoms**: OOM when both training and inference run on same GPUs
|
||||
|
||||
**Solutions**:
|
||||
|
||||
1. **Reduce SGLang memory**:
|
||||
```bash
|
||||
--sglang-mem-fraction-static 0.4 # Reduce from 0.8
|
||||
```
|
||||
|
||||
2. **Enable offloading**:
|
||||
```bash
|
||||
--offload-optimizer-states
|
||||
```
|
||||
|
||||
3. **Use smaller sequence length**:
|
||||
```bash
|
||||
--seq-length 2048 # Reduce from 4096
|
||||
```
|
||||
|
||||
### Data Loading Issues
|
||||
|
||||
#### Issue: Slow Data Loading
|
||||
|
||||
**Symptoms**: GPU idle during data fetch, low GPU utilization
|
||||
|
||||
**Solutions**:
|
||||
|
||||
1. **Increase data workers**:
|
||||
```bash
|
||||
--num-data-workers 4
|
||||
```
|
||||
|
||||
2. **Use streaming dataset**:
|
||||
```bash
|
||||
--streaming-data
|
||||
```
|
||||
|
||||
3. **Pre-tokenize data**:
|
||||
```python
|
||||
# Pre-process data offline
|
||||
from transformers import AutoTokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained("model_path")
|
||||
# Save tokenized data
|
||||
```
|
||||
|
||||
#### Issue: Data Format Errors
|
||||
|
||||
**Symptoms**: KeyError, missing fields, parsing failures
|
||||
|
||||
**Solutions**:
|
||||
|
||||
1. **Verify data format**:
|
||||
```python
|
||||
import json
|
||||
with open("data.jsonl") as f:
|
||||
for line in f:
|
||||
data = json.loads(line)
|
||||
assert "prompt" in data, "Missing prompt field"
|
||||
assert "label" in data, "Missing label field"
|
||||
```
|
||||
|
||||
2. **Check key names**:
|
||||
```bash
|
||||
--input-key prompt # Must match your data
|
||||
--label-key label # Must match your data
|
||||
```
|
||||
|
||||
### Training Stability Issues
|
||||
|
||||
#### Issue: Loss Explosion / NaN
|
||||
|
||||
**Symptoms**: Loss becomes NaN or explodes
|
||||
|
||||
**Solutions**:
|
||||
|
||||
1. **Reduce learning rate**:
|
||||
```bash
|
||||
--lr 1e-6 # Reduce from 5e-6
|
||||
```
|
||||
|
||||
2. **Enable gradient clipping**:
|
||||
```bash
|
||||
--clip-grad 1.0
|
||||
```
|
||||
|
||||
3. **Check for data issues**:
|
||||
```python
|
||||
# Verify no empty prompts or responses
|
||||
for sample in dataset:
|
||||
assert len(sample["prompt"]) > 0
|
||||
```
|
||||
|
||||
4. **Use BF16 instead of FP16**:
|
||||
```bash
|
||||
--bf16 # More numerically stable
|
||||
```
|
||||
|
||||
#### Issue: Reward Collapse
|
||||
|
||||
**Symptoms**: Reward drops to zero, model outputs garbage
|
||||
|
||||
**Solutions**:
|
||||
|
||||
1. **Increase KL penalty**:
|
||||
```bash
|
||||
--kl-loss-coef 0.01 # Increase from 0.001
|
||||
```
|
||||
|
||||
2. **Reduce number of samples**:
|
||||
```bash
|
||||
--n-samples-per-prompt 4 # Reduce from 8
|
||||
```
|
||||
|
||||
3. **Verify reward function**:
|
||||
```python
|
||||
# Test reward function independently
|
||||
from custom_rm import reward_func
|
||||
sample = Sample(prompt="test", response="test response")
|
||||
reward = reward_func(args, sample)
|
||||
print(f"Reward: {reward}") # Should be reasonable
|
||||
```
|
||||
|
||||
### Async Training Issues
|
||||
|
||||
#### Issue: Async Training Not Supported with Colocate
|
||||
|
||||
**Symptoms**: Error when using `--colocate` with `train_async.py`
|
||||
|
||||
**Solution**: Colocated mode is NOT supported for async training. Use separate GPUs:
|
||||
```bash
|
||||
# Remove --colocate flag
|
||||
python train_async.py \
|
||||
--actor-num-gpus-per-node 4 \
|
||||
--rollout-num-gpus 4 \
|
||||
# No --colocate
|
||||
```
|
||||
|
||||
#### Issue: Stale Weights in Async Mode
|
||||
|
||||
**Symptoms**: Policy divergence, inconsistent behavior
|
||||
|
||||
**Solutions**:
|
||||
|
||||
1. **Reduce async buffer size**:
|
||||
```bash
|
||||
--async-buffer-size 2 # Reduce from 4
|
||||
```
|
||||
|
||||
2. **Increase weight update frequency**:
|
||||
```bash
|
||||
--update-weights-interval 1 # Sync every rollout
|
||||
```
|
||||
|
||||
### Multi-Turn Training Issues
|
||||
|
||||
#### Issue: Tool Responses Included in Loss
|
||||
|
||||
**Symptoms**: Model learns to output tool responses verbatim
|
||||
|
||||
**Solution**: Properly set loss mask in custom generate function:
|
||||
```python
|
||||
def build_loss_mask(sample):
|
||||
"""Create loss mask that excludes tool responses."""
|
||||
mask = []
|
||||
for i, token in enumerate(sample.tokens):
|
||||
if is_tool_response(token, sample.metadata):
|
||||
mask.append(0) # Don't compute loss
|
||||
else:
|
||||
mask.append(1) # Compute loss
|
||||
return mask
|
||||
```
|
||||
|
||||
#### Issue: Multi-Turn Context Too Long
|
||||
|
||||
**Symptoms**: OOM or truncation in multi-turn conversations
|
||||
|
||||
**Solutions**:
|
||||
|
||||
1. **Limit conversation history**:
|
||||
```python
|
||||
# In custom generate function
|
||||
conversation = sample.prompt[-10:] # Keep last 10 turns
|
||||
```
|
||||
|
||||
2. **Increase context length**:
|
||||
```bash
|
||||
--sglang-context-length 16384
|
||||
```
|
||||
|
||||
### Checkpoint Issues
|
||||
|
||||
#### Issue: Checkpoint Loading Fails
|
||||
|
||||
**Symptoms**: Cannot load saved checkpoint
|
||||
|
||||
**Solutions**:
|
||||
|
||||
1. **Verify checkpoint path**:
|
||||
```bash
|
||||
ls -la /path/to/checkpoint/
|
||||
```
|
||||
|
||||
2. **Check parallelism matches**:
|
||||
```bash
|
||||
# Checkpoint was saved with TP=2, must load with TP=2
|
||||
--tensor-model-parallel-size 2
|
||||
```
|
||||
|
||||
3. **Convert HuggingFace to Megatron** (if needed):
|
||||
```bash
|
||||
python tools/convert_hf_to_megatron.py \
|
||||
--hf_model_path /path/to/hf/model \
|
||||
--save_path /path/to/megatron/checkpoint
|
||||
```
|
||||
|
||||
### Debugging Tips
|
||||
|
||||
#### Enable Verbose Logging
|
||||
|
||||
```bash
|
||||
--log-level DEBUG
|
||||
export SLIME_DEBUG=1
|
||||
```
|
||||
|
||||
#### Check GPU Utilization
|
||||
|
||||
```bash
|
||||
watch -n 1 nvidia-smi
|
||||
```
|
||||
|
||||
#### Monitor Training
|
||||
|
||||
```bash
|
||||
tensorboard --logdir outputs/
|
||||
```
|
||||
|
||||
#### Test Custom Functions Independently
|
||||
|
||||
```python
|
||||
# Test reward function
|
||||
import asyncio
|
||||
from custom_rm import reward_func
|
||||
|
||||
async def test():
|
||||
sample = Sample(prompt="test", response="test", label="expected")
|
||||
reward = await reward_func(args, sample)
|
||||
print(f"Reward: {reward}")
|
||||
|
||||
asyncio.run(test())
|
||||
```
|
||||
|
||||
## Constraint Reference
|
||||
|
||||
Key constraint to remember:
|
||||
|
||||
```
|
||||
rollout_batch_size × n_samples_per_prompt = global_batch_size × num_steps_per_rollout
|
||||
```
|
||||
|
||||
Example: `32 × 8 = 256 × 1`
|
||||
|
||||
## Resources
|
||||
|
||||
- GitHub Issues: https://github.com/THUDM/slime/issues
|
||||
- Documentation: https://thudm.github.io/slime/
|
||||
- Examples: `examples/` directory
|
||||
|
|
@ -1,361 +0,0 @@
|
|||
---
|
||||
name: distributed-llm-pretraining-torchtitan
|
||||
description: Provides PyTorch-native distributed LLM pretraining using torchtitan with 4D parallelism (FSDP2, TP, PP, CP). Use when pretraining Llama 3.1, DeepSeek V3, or custom models at scale from 8 to 512+ GPUs with Float8, torch.compile, and distributed checkpointing.
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
dependencies: [torch>=2.6.0, torchtitan>=0.2.0, torchao>=0.5.0]
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [Model Architecture, Distributed Training, TorchTitan, FSDP2, Tensor Parallel, Pipeline Parallel, Context Parallel, Float8, Llama, Pretraining]
|
||||
|
||||
---
|
||||
|
||||
# TorchTitan - PyTorch Native Distributed LLM Pretraining
|
||||
|
||||
## Quick start
|
||||
|
||||
TorchTitan is PyTorch's official platform for large-scale LLM pretraining with composable 4D parallelism (FSDP2, TP, PP, CP), achieving 65%+ speedups over baselines on H100 GPUs.
|
||||
|
||||
**Installation**:
|
||||
```bash
|
||||
# From PyPI (stable)
|
||||
pip install torchtitan
|
||||
|
||||
# From source (latest features, requires PyTorch nightly)
|
||||
git clone https://github.com/pytorch/torchtitan
|
||||
cd torchtitan
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
**Download tokenizer**:
|
||||
```bash
|
||||
# Get HF token from https://huggingface.co/settings/tokens
|
||||
python scripts/download_hf_assets.py --repo_id meta-llama/Llama-3.1-8B --assets tokenizer --hf_token=...
|
||||
```
|
||||
|
||||
**Start training on 8 GPUs**:
|
||||
```bash
|
||||
CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh
|
||||
```
|
||||
|
||||
## Common workflows
|
||||
|
||||
### Workflow 1: Pretrain Llama 3.1 8B on single node
|
||||
|
||||
Copy this checklist:
|
||||
|
||||
```
|
||||
Single Node Pretraining:
|
||||
- [ ] Step 1: Download tokenizer
|
||||
- [ ] Step 2: Configure training
|
||||
- [ ] Step 3: Launch training
|
||||
- [ ] Step 4: Monitor and checkpoint
|
||||
```
|
||||
|
||||
**Step 1: Download tokenizer**
|
||||
|
||||
```bash
|
||||
python scripts/download_hf_assets.py \
|
||||
--repo_id meta-llama/Llama-3.1-8B \
|
||||
--assets tokenizer \
|
||||
--hf_token=YOUR_HF_TOKEN
|
||||
```
|
||||
|
||||
**Step 2: Configure training**
|
||||
|
||||
Edit or create a TOML config file:
|
||||
|
||||
```toml
|
||||
# llama3_8b_custom.toml
|
||||
[job]
|
||||
dump_folder = "./outputs"
|
||||
description = "Llama 3.1 8B training"
|
||||
|
||||
[model]
|
||||
name = "llama3"
|
||||
flavor = "8B"
|
||||
hf_assets_path = "./assets/hf/Llama-3.1-8B"
|
||||
|
||||
[optimizer]
|
||||
name = "AdamW"
|
||||
lr = 3e-4
|
||||
|
||||
[lr_scheduler]
|
||||
warmup_steps = 200
|
||||
|
||||
[training]
|
||||
local_batch_size = 2
|
||||
seq_len = 8192
|
||||
max_norm = 1.0
|
||||
steps = 1000
|
||||
dataset = "c4"
|
||||
|
||||
[parallelism]
|
||||
data_parallel_shard_degree = -1 # Use all GPUs for FSDP
|
||||
|
||||
[activation_checkpoint]
|
||||
mode = "selective"
|
||||
selective_ac_option = "op"
|
||||
|
||||
[checkpoint]
|
||||
enable = true
|
||||
folder = "checkpoint"
|
||||
interval = 500
|
||||
```
|
||||
|
||||
**Step 3: Launch training**
|
||||
|
||||
```bash
|
||||
# 8 GPUs on single node
|
||||
CONFIG_FILE="./llama3_8b_custom.toml" ./run_train.sh
|
||||
|
||||
# Or explicitly with torchrun
|
||||
torchrun --nproc_per_node=8 \
|
||||
-m torchtitan.train \
|
||||
--job.config_file ./llama3_8b_custom.toml
|
||||
```
|
||||
|
||||
**Step 4: Monitor and checkpoint**
|
||||
|
||||
TensorBoard logs are saved to `./outputs/tb/`:
|
||||
```bash
|
||||
tensorboard --logdir ./outputs/tb
|
||||
```
|
||||
|
||||
### Workflow 2: Multi-node training with SLURM
|
||||
|
||||
```
|
||||
Multi-Node Training:
|
||||
- [ ] Step 1: Configure parallelism for scale
|
||||
- [ ] Step 2: Set up SLURM script
|
||||
- [ ] Step 3: Submit job
|
||||
- [ ] Step 4: Resume from checkpoint
|
||||
```
|
||||
|
||||
**Step 1: Configure parallelism for scale**
|
||||
|
||||
For 70B model on 256 GPUs (32 nodes):
|
||||
```toml
|
||||
[parallelism]
|
||||
data_parallel_shard_degree = 32 # FSDP across 32 ranks
|
||||
tensor_parallel_degree = 8 # TP within node
|
||||
pipeline_parallel_degree = 1 # No PP for 70B
|
||||
context_parallel_degree = 1 # Increase for long sequences
|
||||
```
|
||||
|
||||
**Step 2: Set up SLURM script**
|
||||
|
||||
```bash
|
||||
#!/bin/bash
|
||||
#SBATCH --job-name=llama70b
|
||||
#SBATCH --nodes=32
|
||||
#SBATCH --ntasks-per-node=8
|
||||
#SBATCH --gpus-per-node=8
|
||||
|
||||
srun torchrun \
|
||||
--nnodes=32 \
|
||||
--nproc_per_node=8 \
|
||||
--rdzv_backend=c10d \
|
||||
--rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT \
|
||||
-m torchtitan.train \
|
||||
--job.config_file ./llama3_70b.toml
|
||||
```
|
||||
|
||||
**Step 3: Submit job**
|
||||
|
||||
```bash
|
||||
sbatch multinode_trainer.slurm
|
||||
```
|
||||
|
||||
**Step 4: Resume from checkpoint**
|
||||
|
||||
Training auto-resumes if checkpoint exists in configured folder.
|
||||
|
||||
### Workflow 3: Enable Float8 training for H100s
|
||||
|
||||
Float8 provides 30-50% speedup on H100 GPUs.
|
||||
|
||||
```
|
||||
Float8 Training:
|
||||
- [ ] Step 1: Install torchao
|
||||
- [ ] Step 2: Configure Float8
|
||||
- [ ] Step 3: Launch with compile
|
||||
```
|
||||
|
||||
**Step 1: Install torchao**
|
||||
|
||||
```bash
|
||||
USE_CPP=0 pip install git+https://github.com/pytorch/ao.git
|
||||
```
|
||||
|
||||
**Step 2: Configure Float8**
|
||||
|
||||
Add to your TOML config:
|
||||
```toml
|
||||
[model]
|
||||
converters = ["quantize.linear.float8"]
|
||||
|
||||
[quantize.linear.float8]
|
||||
enable_fsdp_float8_all_gather = true
|
||||
precompute_float8_dynamic_scale_for_fsdp = true
|
||||
filter_fqns = ["output"] # Exclude output layer
|
||||
|
||||
[compile]
|
||||
enable = true
|
||||
components = ["model", "loss"]
|
||||
```
|
||||
|
||||
**Step 3: Launch with compile**
|
||||
|
||||
```bash
|
||||
CONFIG_FILE="./llama3_8b.toml" ./run_train.sh \
|
||||
--model.converters="quantize.linear.float8" \
|
||||
--quantize.linear.float8.enable_fsdp_float8_all_gather \
|
||||
--compile.enable
|
||||
```
|
||||
|
||||
### Workflow 4: 4D parallelism for 405B models
|
||||
|
||||
```
|
||||
4D Parallelism (FSDP + TP + PP + CP):
|
||||
- [ ] Step 1: Create seed checkpoint
|
||||
- [ ] Step 2: Configure 4D parallelism
|
||||
- [ ] Step 3: Launch on 512 GPUs
|
||||
```
|
||||
|
||||
**Step 1: Create seed checkpoint**
|
||||
|
||||
Required for consistent initialization across PP stages:
|
||||
```bash
|
||||
NGPU=1 CONFIG_FILE=./llama3_405b.toml ./run_train.sh \
|
||||
--checkpoint.enable \
|
||||
--checkpoint.create_seed_checkpoint \
|
||||
--parallelism.data_parallel_shard_degree 1 \
|
||||
--parallelism.tensor_parallel_degree 1 \
|
||||
--parallelism.pipeline_parallel_degree 1
|
||||
```
|
||||
|
||||
**Step 2: Configure 4D parallelism**
|
||||
|
||||
```toml
|
||||
[parallelism]
|
||||
data_parallel_shard_degree = 8 # FSDP
|
||||
tensor_parallel_degree = 8 # TP within node
|
||||
pipeline_parallel_degree = 8 # PP across nodes
|
||||
context_parallel_degree = 1 # CP for long sequences
|
||||
|
||||
[training]
|
||||
local_batch_size = 32
|
||||
seq_len = 8192
|
||||
```
|
||||
|
||||
**Step 3: Launch on 512 GPUs**
|
||||
|
||||
```bash
|
||||
# 64 nodes x 8 GPUs = 512 GPUs
|
||||
srun torchrun --nnodes=64 --nproc_per_node=8 \
|
||||
-m torchtitan.train \
|
||||
--job.config_file ./llama3_405b.toml
|
||||
```
|
||||
|
||||
## When to use vs alternatives
|
||||
|
||||
**Use TorchTitan when:**
|
||||
- Pretraining LLMs from scratch (8B to 405B+)
|
||||
- Need PyTorch-native solution without third-party dependencies
|
||||
- Require composable 4D parallelism (FSDP2, TP, PP, CP)
|
||||
- Training on H100s with Float8 support
|
||||
- Want interoperable checkpoints with torchtune/HuggingFace
|
||||
|
||||
**Use alternatives instead:**
|
||||
- **Megatron-LM**: Maximum performance for NVIDIA-only deployments
|
||||
- **DeepSpeed**: Broader ZeRO optimization ecosystem, inference support
|
||||
- **Axolotl/TRL**: Fine-tuning rather than pretraining
|
||||
- **LitGPT**: Educational, smaller-scale training
|
||||
|
||||
## Common issues
|
||||
|
||||
**Issue: Out of memory on large models**
|
||||
|
||||
Enable activation checkpointing and reduce batch size:
|
||||
```toml
|
||||
[activation_checkpoint]
|
||||
mode = "full" # Instead of "selective"
|
||||
|
||||
[training]
|
||||
local_batch_size = 1
|
||||
```
|
||||
|
||||
Or use gradient accumulation:
|
||||
```toml
|
||||
[training]
|
||||
local_batch_size = 1
|
||||
global_batch_size = 32 # Accumulates gradients
|
||||
```
|
||||
|
||||
**Issue: TP causes high memory with async collectives**
|
||||
|
||||
Set environment variable:
|
||||
```bash
|
||||
export TORCH_NCCL_AVOID_RECORD_STREAMS=1
|
||||
```
|
||||
|
||||
**Issue: Float8 training not faster**
|
||||
|
||||
Float8 only benefits large GEMMs. Filter small layers:
|
||||
```toml
|
||||
[quantize.linear.float8]
|
||||
filter_fqns = ["attention.wk", "attention.wv", "output", "auto_filter_small_kn"]
|
||||
```
|
||||
|
||||
**Issue: Checkpoint loading fails after parallelism change**
|
||||
|
||||
Use DCP's resharding capability:
|
||||
```bash
|
||||
# Convert sharded checkpoint to single file
|
||||
python -m torch.distributed.checkpoint.format_utils \
|
||||
dcp_to_torch checkpoint/step-1000 checkpoint.pt
|
||||
```
|
||||
|
||||
**Issue: Pipeline parallelism initialization**
|
||||
|
||||
Create seed checkpoint first (see Workflow 4, Step 1).
|
||||
|
||||
## Supported models
|
||||
|
||||
| Model | Sizes | Status |
|
||||
|-------|-------|--------|
|
||||
| Llama 3.1 | 8B, 70B, 405B | Production |
|
||||
| Llama 4 | Various | Experimental |
|
||||
| DeepSeek V3 | 16B, 236B, 671B (MoE) | Experimental |
|
||||
| GPT-OSS | 20B, 120B (MoE) | Experimental |
|
||||
| Qwen 3 | Various | Experimental |
|
||||
| Flux | Diffusion | Experimental |
|
||||
|
||||
## Performance benchmarks (H100)
|
||||
|
||||
| Model | GPUs | Parallelism | TPS/GPU | Techniques |
|
||||
|-------|------|-------------|---------|------------|
|
||||
| Llama 8B | 8 | FSDP | 5,762 | Baseline |
|
||||
| Llama 8B | 8 | FSDP+compile+FP8 | 8,532 | +48% |
|
||||
| Llama 70B | 256 | FSDP+TP+AsyncTP | 876 | 2D parallel |
|
||||
| Llama 405B | 512 | FSDP+TP+PP | 128 | 3D parallel |
|
||||
|
||||
## Advanced topics
|
||||
|
||||
**FSDP2 configuration**: See [references/fsdp.md](references/fsdp.md) for detailed FSDP2 vs FSDP1 comparison and ZeRO equivalents.
|
||||
|
||||
**Float8 training**: See [references/float8.md](references/float8.md) for tensorwise vs rowwise scaling recipes.
|
||||
|
||||
**Checkpointing**: See [references/checkpoint.md](references/checkpoint.md) for HuggingFace conversion and async checkpointing.
|
||||
|
||||
**Adding custom models**: See [references/custom-models.md](references/custom-models.md) for TrainSpec protocol.
|
||||
|
||||
## Resources
|
||||
|
||||
- GitHub: https://github.com/pytorch/torchtitan
|
||||
- Paper: https://arxiv.org/abs/2410.06511
|
||||
- ICLR 2025: https://iclr.cc/virtual/2025/poster/29620
|
||||
- PyTorch Forum: https://discuss.pytorch.org/c/distributed/torchtitan/44
|
||||
|
||||
|
|
@ -1,181 +0,0 @@
|
|||
# Checkpointing in TorchTitan
|
||||
|
||||
TorchTitan uses PyTorch Distributed Checkpoint (DCP) for fault-tolerant, interoperable checkpointing.
|
||||
|
||||
## Basic Configuration
|
||||
|
||||
```toml
|
||||
[checkpoint]
|
||||
enable = true
|
||||
folder = "checkpoint"
|
||||
interval = 500
|
||||
```
|
||||
|
||||
## Save Model Only (Smaller Checkpoints)
|
||||
|
||||
Exclude optimizer state and training metadata:
|
||||
|
||||
```toml
|
||||
[checkpoint]
|
||||
enable = true
|
||||
last_save_model_only = true
|
||||
export_dtype = "bfloat16" # Optional: export in lower precision
|
||||
```
|
||||
|
||||
## Excluding Keys from Loading
|
||||
|
||||
Partial checkpoint loading for modified settings:
|
||||
|
||||
```toml
|
||||
[checkpoint]
|
||||
enable = true
|
||||
exclude_from_loading = ["data_loader", "lr_scheduler"]
|
||||
```
|
||||
|
||||
CLI equivalent:
|
||||
```bash
|
||||
--checkpoint.exclude_from_loading data_loader,lr_scheduler
|
||||
```
|
||||
|
||||
## Creating Seed Checkpoints
|
||||
|
||||
Required for Pipeline Parallelism to ensure consistent initialization:
|
||||
|
||||
```bash
|
||||
NGPU=1 CONFIG_FILE=<path_to_config> ./run_train.sh \
|
||||
--checkpoint.enable \
|
||||
--checkpoint.create_seed_checkpoint \
|
||||
--parallelism.data_parallel_replicate_degree 1 \
|
||||
--parallelism.data_parallel_shard_degree 1 \
|
||||
--parallelism.tensor_parallel_degree 1 \
|
||||
--parallelism.pipeline_parallel_degree 1 \
|
||||
--parallelism.context_parallel_degree 1 \
|
||||
--parallelism.expert_parallel_degree 1
|
||||
```
|
||||
|
||||
This initializes on single CPU for reproducible initialization across any GPU count.
|
||||
|
||||
## Async Checkpointing
|
||||
|
||||
Reduce checkpoint overhead with async writes:
|
||||
|
||||
```toml
|
||||
[checkpoint]
|
||||
enable = true
|
||||
async_mode = "async" # Options: "disabled", "async", "async_with_pinned_mem"
|
||||
```
|
||||
|
||||
## HuggingFace Conversion
|
||||
|
||||
### During Training
|
||||
|
||||
Save directly in HuggingFace format:
|
||||
|
||||
```toml
|
||||
[checkpoint]
|
||||
last_save_in_hf = true
|
||||
last_save_model_only = true
|
||||
```
|
||||
|
||||
Load from HuggingFace:
|
||||
|
||||
```toml
|
||||
[checkpoint]
|
||||
initial_load_in_hf = true
|
||||
|
||||
[model]
|
||||
hf_assets_path = "./path/to/hf/checkpoint"
|
||||
```
|
||||
|
||||
### Offline Conversion
|
||||
|
||||
Convert without running training:
|
||||
|
||||
```bash
|
||||
# HuggingFace -> TorchTitan
|
||||
python ./scripts/checkpoint_conversion/convert_from_hf.py \
|
||||
<input_dir> <output_dir> \
|
||||
--model_name llama3 \
|
||||
--model_flavor 8B
|
||||
|
||||
# TorchTitan -> HuggingFace
|
||||
python ./scripts/checkpoint_conversion/convert_to_hf.py \
|
||||
<input_dir> <output_dir> \
|
||||
--hf_assets_path ./assets/hf/Llama3.1-8B \
|
||||
--model_name llama3 \
|
||||
--model_flavor 8B
|
||||
```
|
||||
|
||||
### Example
|
||||
|
||||
```bash
|
||||
python ./scripts/convert_from_hf.py \
|
||||
~/.cache/huggingface/hub/models--meta-llama--Meta-Llama-3-8B/snapshots/8cde5ca8380496c9a6cc7ef3a8b46a0372a1d920/ \
|
||||
./initial_load_path/ \
|
||||
--model_name llama3 \
|
||||
--model_flavor 8B
|
||||
```
|
||||
|
||||
## Converting to Single .pt File
|
||||
|
||||
Convert DCP sharded checkpoint to single PyTorch file:
|
||||
|
||||
```bash
|
||||
python -m torch.distributed.checkpoint.format_utils \
|
||||
dcp_to_torch \
|
||||
torchtitan/outputs/checkpoint/step-1000 \
|
||||
checkpoint.pt
|
||||
```
|
||||
|
||||
## Checkpoint Structure
|
||||
|
||||
DCP saves sharded checkpoints that can be resharded for different parallelism configurations:
|
||||
|
||||
```
|
||||
checkpoint/
|
||||
├── step-500/
|
||||
│ ├── .metadata
|
||||
│ ├── __0_0.distcp
|
||||
│ ├── __0_1.distcp
|
||||
│ └── ...
|
||||
└── step-1000/
|
||||
└── ...
|
||||
```
|
||||
|
||||
## Resume Training
|
||||
|
||||
Training auto-resumes from the latest checkpoint in the configured folder. To resume from a specific step:
|
||||
|
||||
```toml
|
||||
[checkpoint]
|
||||
load_step = 500 # Resume from step 500
|
||||
```
|
||||
|
||||
## Interoperability with TorchTune
|
||||
|
||||
Checkpoints saved with `last_save_model_only = true` can be loaded directly into [torchtune](https://github.com/pytorch/torchtune) for fine-tuning.
|
||||
|
||||
## Full Configuration Example
|
||||
|
||||
```toml
|
||||
[checkpoint]
|
||||
enable = true
|
||||
folder = "checkpoint"
|
||||
interval = 500
|
||||
load_step = -1 # -1 = latest, or specify step number
|
||||
last_save_model_only = true
|
||||
export_dtype = "bfloat16"
|
||||
async_mode = "async"
|
||||
exclude_from_loading = []
|
||||
last_save_in_hf = false
|
||||
initial_load_in_hf = false
|
||||
create_seed_checkpoint = false
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
1. **Large models**: Use `async_mode = "async"` to overlap checkpoint saves with training
|
||||
2. **Fine-tuning export**: Enable `last_save_model_only` and `export_dtype = "bfloat16"` for smaller files
|
||||
3. **Pipeline parallelism**: Always create seed checkpoint first
|
||||
4. **Debugging**: Save frequent checkpoints during development, reduce for production
|
||||
5. **HF interop**: Use conversion scripts for offline conversion, direct save/load for training workflows
|
||||
|
|
@ -1,258 +0,0 @@
|
|||
# Adding Custom Models to TorchTitan
|
||||
|
||||
This guide explains how to add a new model to TorchTitan following the established patterns.
|
||||
|
||||
## Directory Structure
|
||||
|
||||
```
|
||||
torchtitan/models/your_model/
|
||||
├── model/
|
||||
│ ├── __init__.py
|
||||
│ ├── args.py # Model arguments
|
||||
│ ├── model.py # Model definition
|
||||
│ └── state_dict_adapter.py # HF conversion (optional)
|
||||
├── infra/
|
||||
│ ├── __init__.py
|
||||
│ ├── parallelize.py # TP, FSDP, compile application
|
||||
│ └── pipeline.py # PP application (optional)
|
||||
├── train_configs/
|
||||
│ ├── debug_model.toml
|
||||
│ └── your_model_XB.toml
|
||||
├── __init__.py # TrainSpec registration
|
||||
└── README.md
|
||||
```
|
||||
|
||||
## Step 1: Define Model Arguments
|
||||
|
||||
Inherit from `BaseModelArgs`:
|
||||
|
||||
```python
|
||||
# model/args.py
|
||||
from torchtitan.protocols.model import BaseModelArgs
|
||||
from dataclasses import dataclass
|
||||
|
||||
@dataclass
|
||||
class YourModelArgs(BaseModelArgs):
|
||||
dim: int = 4096
|
||||
n_layers: int = 32
|
||||
n_heads: int = 32
|
||||
vocab_size: int = 128256
|
||||
|
||||
def get_nparams_and_flops(self, seq_len: int) -> tuple[int, int]:
|
||||
"""Return (num_params, flops_per_token) for throughput calculation."""
|
||||
nparams = self.vocab_size * self.dim + ... # Calculate params
|
||||
flops = 6 * nparams # Approximate: 6 * params for forward+backward
|
||||
return nparams, flops
|
||||
|
||||
def update_from_config(self, job_config) -> "YourModelArgs":
|
||||
"""Update args from training config."""
|
||||
# Override specific args from job_config if needed
|
||||
return self
|
||||
```
|
||||
|
||||
## Step 2: Define Model
|
||||
|
||||
Inherit from `ModelProtocol`:
|
||||
|
||||
```python
|
||||
# model/model.py
|
||||
import torch.nn as nn
|
||||
from torchtitan.protocols.model import ModelProtocol
|
||||
from .args import YourModelArgs
|
||||
|
||||
class YourModel(ModelProtocol):
|
||||
def __init__(self, args: YourModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)
|
||||
self.layers = nn.ModuleDict({
|
||||
str(i): TransformerBlock(args) for i in range(args.n_layers)
|
||||
})
|
||||
self.norm = RMSNorm(args.dim)
|
||||
self.output = nn.Linear(args.dim, args.vocab_size, bias=False)
|
||||
|
||||
def forward(self, tokens: torch.Tensor) -> torch.Tensor:
|
||||
h = self.tok_embeddings(tokens)
|
||||
for layer in self.layers.values():
|
||||
h = layer(h)
|
||||
h = self.norm(h)
|
||||
return self.output(h)
|
||||
|
||||
def init_weights(self):
|
||||
"""Initialize weights recursively."""
|
||||
for module in self.modules():
|
||||
if hasattr(module, 'init_weights') and module is not self:
|
||||
module.init_weights()
|
||||
elif isinstance(module, nn.Linear):
|
||||
nn.init.normal_(module.weight, std=0.02)
|
||||
```
|
||||
|
||||
**Important guidelines**:
|
||||
- Write single-device model code (parallelism applied externally)
|
||||
- Use `nn.ModuleDict` for layers (preserves FQNs when deleting for PP)
|
||||
- Make input/output layers optional for PP compatibility
|
||||
- Define `init_weights()` recursively
|
||||
|
||||
## Step 3: Parallelize Function
|
||||
|
||||
```python
|
||||
# infra/parallelize.py
|
||||
from torch.distributed._composable.fsdp import fully_shard
|
||||
from torch.distributed.tensor.parallel import parallelize_module
|
||||
|
||||
def parallelize_your_model(
|
||||
model: YourModel,
|
||||
world_mesh: DeviceMesh,
|
||||
parallel_dims: ParallelDims,
|
||||
job_config: JobConfig,
|
||||
):
|
||||
# Apply in this order: TP -> AC -> compile -> FSDP
|
||||
|
||||
# 1. Tensor Parallelism
|
||||
if parallel_dims.tp_enabled:
|
||||
apply_tp(model, world_mesh["tp"], job_config)
|
||||
|
||||
# 2. Activation Checkpointing
|
||||
if job_config.activation_checkpoint.mode == "full":
|
||||
apply_ac(model, job_config)
|
||||
|
||||
# 3. torch.compile
|
||||
if job_config.compile.enable:
|
||||
model = torch.compile(model)
|
||||
|
||||
# 4. FSDP
|
||||
if parallel_dims.dp_enabled:
|
||||
apply_fsdp(model, world_mesh["dp"], job_config)
|
||||
|
||||
return model
|
||||
```
|
||||
|
||||
## Step 4: Create TrainSpec
|
||||
|
||||
```python
|
||||
# __init__.py
|
||||
from torchtitan.protocols.train_spec import TrainSpec, register_train_spec
|
||||
from .model.model import YourModel
|
||||
from .model.args import YourModelArgs
|
||||
from .infra.parallelize import parallelize_your_model
|
||||
|
||||
MODEL_CONFIGS = {
|
||||
"8B": YourModelArgs(dim=4096, n_layers=32, n_heads=32),
|
||||
"70B": YourModelArgs(dim=8192, n_layers=80, n_heads=64),
|
||||
}
|
||||
|
||||
def get_train_spec(flavor: str) -> TrainSpec:
|
||||
return TrainSpec(
|
||||
model_cls=YourModel,
|
||||
model_args=MODEL_CONFIGS[flavor],
|
||||
parallelize_fn=parallelize_your_model,
|
||||
pipeline_fn=None, # Or your_pipeline_fn for PP
|
||||
build_optimizer_fn=build_optimizer, # Reuse existing
|
||||
build_lr_scheduler_fn=build_lr_scheduler, # Reuse existing
|
||||
build_dataloader_fn=build_dataloader, # Reuse existing
|
||||
build_tokenizer_fn=build_tokenizer, # Reuse existing
|
||||
build_loss_fn=build_loss, # Reuse existing
|
||||
state_dict_adapter=None, # Or YourStateDictAdapter
|
||||
)
|
||||
|
||||
# Register so train.py can find it
|
||||
register_train_spec("your_model", get_train_spec)
|
||||
```
|
||||
|
||||
## Step 5: State Dict Adapter (Optional)
|
||||
|
||||
For HuggingFace checkpoint conversion:
|
||||
|
||||
```python
|
||||
# model/state_dict_adapter.py
|
||||
from torchtitan.protocols.state_dict_adapter import BaseStateDictAdapter
|
||||
|
||||
class YourStateDictAdapter(BaseStateDictAdapter):
|
||||
def to_hf(self, state_dict: dict) -> dict:
|
||||
"""Convert torchtitan state dict to HF format."""
|
||||
hf_state_dict = {}
|
||||
for key, value in state_dict.items():
|
||||
hf_key = self._convert_key_to_hf(key)
|
||||
hf_state_dict[hf_key] = value
|
||||
return hf_state_dict
|
||||
|
||||
def from_hf(self, state_dict: dict) -> dict:
|
||||
"""Convert HF state dict to torchtitan format."""
|
||||
tt_state_dict = {}
|
||||
for key, value in state_dict.items():
|
||||
tt_key = self._convert_key_from_hf(key)
|
||||
tt_state_dict[tt_key] = value
|
||||
return tt_state_dict
|
||||
```
|
||||
|
||||
## Step 6: Training Config
|
||||
|
||||
```toml
|
||||
# train_configs/your_model_8b.toml
|
||||
[job]
|
||||
dump_folder = "./outputs"
|
||||
description = "Your Model 8B training"
|
||||
|
||||
[model]
|
||||
name = "your_model"
|
||||
flavor = "8B"
|
||||
|
||||
[optimizer]
|
||||
name = "AdamW"
|
||||
lr = 3e-4
|
||||
|
||||
[training]
|
||||
local_batch_size = 2
|
||||
seq_len = 8192
|
||||
steps = 1000
|
||||
dataset = "c4"
|
||||
|
||||
[parallelism]
|
||||
data_parallel_shard_degree = -1
|
||||
tensor_parallel_degree = 1
|
||||
```
|
||||
|
||||
## Step 7: Register Model
|
||||
|
||||
Add to `torchtitan/models/__init__.py`:
|
||||
|
||||
```python
|
||||
from .your_model import get_train_spec as get_your_model_train_spec
|
||||
|
||||
MODEL_REGISTRY["your_model"] = get_your_model_train_spec
|
||||
```
|
||||
|
||||
## Testing
|
||||
|
||||
### Numerics Test
|
||||
|
||||
Compare output with HuggingFace implementation:
|
||||
|
||||
```python
|
||||
def test_numerics():
|
||||
# Load same checkpoint into both implementations
|
||||
tt_model = YourModel(args).load_checkpoint(...)
|
||||
hf_model = HFYourModel.from_pretrained(...)
|
||||
|
||||
# Compare outputs
|
||||
input_ids = torch.randint(0, vocab_size, (1, 128))
|
||||
tt_output = tt_model(input_ids)
|
||||
hf_output = hf_model(input_ids).logits
|
||||
|
||||
torch.testing.assert_close(tt_output, hf_output, atol=1e-4, rtol=1e-4)
|
||||
```
|
||||
|
||||
### Loss Convergence
|
||||
|
||||
Compare loss curves with verified baseline (see `docs/converging.md`).
|
||||
|
||||
### Performance Benchmark
|
||||
|
||||
Add benchmark config to `benchmarks/` folder.
|
||||
|
||||
## Guiding Principles
|
||||
|
||||
1. **Readability over flexibility**: Don't over-abstract
|
||||
2. **Minimal model changes**: Parallelism applied externally
|
||||
3. **Clean, minimal codebase**: Reuse existing components where possible
|
||||
4. **Single-device semantics**: Model code should work on single GPU
|
||||
|
|
@ -1,133 +0,0 @@
|
|||
# Float8 Training in TorchTitan
|
||||
|
||||
Float8 training provides substantial speedups for models where GEMMs are large enough that the FP8 tensorcore speedup outweighs dynamic quantization overhead.
|
||||
|
||||
## Hardware Requirements
|
||||
|
||||
- NVIDIA H100 or newer GPUs (FP8 Tensor Cores)
|
||||
- Blackwell GPUs for MXFP8 training
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
USE_CPP=0 pip install git+https://github.com/pytorch/ao.git
|
||||
```
|
||||
|
||||
## Usage: Tensorwise Scaling
|
||||
|
||||
Standard Float8 with tensorwise dynamic scaling:
|
||||
|
||||
```bash
|
||||
CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh \
|
||||
--model.converters="quantize.linear.float8" \
|
||||
--quantize.linear.float8.enable_fsdp_float8_all_gather \
|
||||
--quantize.linear.float8.precompute_float8_dynamic_scale_for_fsdp \
|
||||
--compile.enable
|
||||
```
|
||||
|
||||
### Key Arguments
|
||||
|
||||
| Argument | Description |
|
||||
|----------|-------------|
|
||||
| `--model.converters="quantize.linear.float8"` | Swap `nn.Linear` with `Float8Linear` |
|
||||
| `--quantize.linear.float8.enable_fsdp_float8_all_gather` | Communicate in float8 to save bandwidth |
|
||||
| `--quantize.linear.float8.precompute_float8_dynamic_scale_for_fsdp` | Single all-reduce for all AMAX/scales |
|
||||
| `--compile.enable` | Required - fuses float8 scaling/casting kernels |
|
||||
|
||||
## Usage: Rowwise Scaling
|
||||
|
||||
Higher accuracy than tensorwise scaling:
|
||||
|
||||
```bash
|
||||
CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh \
|
||||
--model.converters="quantize.linear.float8" \
|
||||
--quantize.linear.float8.recipe_name rowwise \
|
||||
--compile.enable
|
||||
```
|
||||
|
||||
## Filtering Layers
|
||||
|
||||
Not all layers benefit from Float8. Filter small layers:
|
||||
|
||||
```bash
|
||||
--quantize.linear.float8.filter_fqns="attention.wk,attention.wv,output"
|
||||
```
|
||||
|
||||
### Auto-filtering
|
||||
|
||||
Automatically skip layers too small to benefit:
|
||||
|
||||
```bash
|
||||
--quantize.linear.float8.filter_fqns="auto_filter_small_kn"
|
||||
```
|
||||
|
||||
Thresholds based on H100 microbenchmarks where speedup > overhead.
|
||||
|
||||
## TOML Configuration
|
||||
|
||||
```toml
|
||||
[model]
|
||||
converters = ["quantize.linear.float8"]
|
||||
|
||||
[quantize.linear.float8]
|
||||
enable_fsdp_float8_all_gather = true
|
||||
precompute_float8_dynamic_scale_for_fsdp = true
|
||||
filter_fqns = ["output", "auto_filter_small_kn"]
|
||||
|
||||
[compile]
|
||||
enable = true
|
||||
components = ["model", "loss"]
|
||||
```
|
||||
|
||||
## How Float8 Works with Distributed Training
|
||||
|
||||
### Single Device
|
||||
|
||||
Cast input and weight to float8 inside forward before calling `torch._scaled_mm`:
|
||||
|
||||
```python
|
||||
# Float8 matmul requires scales
|
||||
torch._scaled_mm(input_fp8, weight_fp8, scale_a=scale_input, scale_b=scale_weight)
|
||||
```
|
||||
|
||||
### FSDP + Float8
|
||||
|
||||
1. Cast sharded high-precision weights (1/N per rank) to float8
|
||||
2. Perform float8 all-gather (saves bandwidth vs bf16/fp32)
|
||||
3. Communicate `max(abs)` across ranks for scale computation
|
||||
4. At forward start, have unsharded float8 weights ready
|
||||
|
||||
**Net benefit**: Float8 all-gather + amax communication can beat bf16/fp32 all-gather, depending on world size and message size.
|
||||
|
||||
### TP + Float8
|
||||
|
||||
- **Input**: Cast sharded input to float8, all-gather in float8
|
||||
- **Weights**: Communicate `max(abs)` for sharded weights
|
||||
- **Matmul**: Float8 input (unsharded) x float8 weight (sharded) with global scales
|
||||
|
||||
## Scaling Strategies
|
||||
|
||||
| Strategy | Status | Description |
|
||||
|----------|--------|-------------|
|
||||
| Tensorwise dynamic | Stable | Single scale per tensor |
|
||||
| Rowwise dynamic | Alpha | Scale per row, higher accuracy |
|
||||
|
||||
## Performance Gains
|
||||
|
||||
From benchmarks on H100:
|
||||
|
||||
| Configuration | TPS/GPU | vs Baseline |
|
||||
|---------------|---------|-------------|
|
||||
| FSDP only | 5,762 | - |
|
||||
| FSDP + compile | 6,667 | +16% |
|
||||
| FSDP + compile + Float8 | 8,532 | +48% |
|
||||
|
||||
## Determining Float8 Benefit
|
||||
|
||||
Check [torchao microbenchmarks](https://github.com/pytorch/ao/tree/main/torchao/float8#performance) for forward+backward pass speedups on "layer norm => linear => sigmoid" for different M,N,K sizes.
|
||||
|
||||
Rule of thumb: GEMMs with K,N > 4096 typically benefit from Float8.
|
||||
|
||||
## MXFP8 Training (Blackwell)
|
||||
|
||||
For NVIDIA Blackwell GPUs, TorchTitan supports MXFP8 (Microscaling FP8) for both dense and MoE models. See [docs/mxfp8.md](https://github.com/pytorch/torchtitan/blob/main/docs/mxfp8.md) for details.
|
||||
|
|
@ -1,126 +0,0 @@
|
|||
# FSDP2 in TorchTitan
|
||||
|
||||
## Why FSDP2?
|
||||
|
||||
FSDP2 is a rewrite of PyTorch's Fully Sharded Data Parallel (FSDP) API, removing the `FlatParameter` abstraction for better composability and simpler implementation.
|
||||
|
||||
### Key improvements over FSDP1
|
||||
|
||||
- **DTensor-based sharding**: Sharded parameters are `DTensor`s on dim-0, enabling easy manipulation and communication-free sharded state dicts
|
||||
- **Better memory management**: Deterministic and lower GPU memory (7% reduction) by avoiding `recordStream`
|
||||
- **Simplified API**: Fewer arguments, no wrapper class
|
||||
|
||||
### Performance
|
||||
|
||||
On Llama-7B with 8x H100s, FSDP2 achieves higher MFU with 7% lower peak memory than FSDP1, matching the same loss curve.
|
||||
|
||||
## API Reference
|
||||
|
||||
```python
|
||||
from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy, OffloadPolicy
|
||||
|
||||
@contract(state_cls=FSDPState)
|
||||
def fully_shard(
|
||||
module: nn.Module,
|
||||
*,
|
||||
mesh: Optional[DeviceMesh] = None,
|
||||
reshard_after_forward: Union[bool, int] = True,
|
||||
mp_policy: MixedPrecisionPolicy = MixedPrecisionPolicy(),
|
||||
offload_policy: OffloadPolicy = OffloadPolicy(),
|
||||
) -> nn.Module:
|
||||
```
|
||||
|
||||
## Sharding Strategies (ZeRO Equivalents)
|
||||
|
||||
| FSDP2 Configuration | FSDP1 Equivalent | DeepSpeed |
|
||||
|---------------------|------------------|-----------|
|
||||
| 1D mesh + `reshard_after_forward=True` | FULL_SHARD | ZeRO-3 |
|
||||
| 1D mesh + `reshard_after_forward=False` | SHARD_GRAD_OP | ZeRO-2 |
|
||||
| 2D mesh + `reshard_after_forward=True` | HYBRID_SHARD | MiCS |
|
||||
| 1D/2D mesh + `reshard_after_forward=8` (int) | - | ZeRO++ hpZ |
|
||||
|
||||
## Meta-Device Initialization
|
||||
|
||||
FSDP2 supports materializing tensors onto GPU _after_ sharding:
|
||||
|
||||
```python
|
||||
# Initialize on meta device (no memory)
|
||||
with torch.device("meta"):
|
||||
model = Transformer()
|
||||
|
||||
# Apply FSDP2 sharding
|
||||
for module in model.modules():
|
||||
if isinstance(module, TransformerBlock):
|
||||
fully_shard(module)
|
||||
fully_shard(model)
|
||||
|
||||
# Parameters still on meta device
|
||||
for tensor in itertools.chain(model.parameters(), model.buffers()):
|
||||
assert tensor.device == torch.device("meta")
|
||||
|
||||
# Allocate sharded parameters on GPU
|
||||
model.to_empty(device="cuda")
|
||||
|
||||
# Initialize weights
|
||||
model.init_weights()
|
||||
```
|
||||
|
||||
## State Dict Differences
|
||||
|
||||
| Operation | FSDP1 | FSDP2 |
|
||||
|-----------|-------|-------|
|
||||
| `model.state_dict()` | Full state dict | Sharded state dict (no communication) |
|
||||
| `optim.state_dict()` | Local state dict | Sharded state dict (no communication) |
|
||||
| `summon_full_params()` | Supported | Use `DTensor` APIs like `full_tensor()` |
|
||||
| Gradient clipping | `FSDP.clip_grad_norm_()` | `nn.utils.clip_grad_norm_()` |
|
||||
|
||||
## Mixed Precision
|
||||
|
||||
```python
|
||||
from torch.distributed._composable.fsdp import MixedPrecisionPolicy
|
||||
|
||||
mp_policy = MixedPrecisionPolicy(
|
||||
param_dtype=torch.bfloat16,
|
||||
reduce_dtype=torch.float32,
|
||||
output_dtype=torch.bfloat16,
|
||||
cast_forward_inputs=True,
|
||||
)
|
||||
|
||||
fully_shard(model, mp_policy=mp_policy)
|
||||
```
|
||||
|
||||
## HSDP (Hybrid Sharded Data Parallel)
|
||||
|
||||
For 2D parallelism with replication + sharding:
|
||||
|
||||
```python
|
||||
from torch.distributed.device_mesh import init_device_mesh
|
||||
|
||||
# Replicate across 4 groups, shard within 8 GPUs each
|
||||
mesh = init_device_mesh("cuda", (4, 8), mesh_dim_names=("replicate", "shard"))
|
||||
|
||||
fully_shard(model, mesh=mesh)
|
||||
```
|
||||
|
||||
## Configuration in TorchTitan
|
||||
|
||||
```toml
|
||||
[parallelism]
|
||||
# FSDP sharding degree (-1 = auto, use all available GPUs)
|
||||
data_parallel_shard_degree = -1
|
||||
|
||||
# HSDP replication degree (1 = pure FSDP, >1 = HSDP)
|
||||
data_parallel_replicate_degree = 1
|
||||
```
|
||||
|
||||
## Removed Arguments from FSDP1
|
||||
|
||||
These FSDP1 arguments are no longer needed:
|
||||
|
||||
- `auto_wrap_policy`: Apply `fully_shard` directly to modules
|
||||
- `backward_prefetch`: Always uses BACKWARD_PRE
|
||||
- `param_init_fn`: Use meta-device initialization
|
||||
- `device_id`: Uses mesh's device automatically
|
||||
- `sync_module_states`: Not needed with DTensor
|
||||
- `limit_all_gathers`: New memory management doesn't need it
|
||||
- `use_orig_params`: Always true (no FlatParameter)
|
||||
|
|
@ -1,458 +0,0 @@
|
|||
---
|
||||
name: fine-tuning-with-trl
|
||||
description: Fine-tune LLMs using reinforcement learning with TRL - SFT for instruction tuning, DPO for preference alignment, PPO/GRPO for reward optimization, and reward model training. Use when need RLHF, align model with preferences, or train from human feedback. Works with HuggingFace Transformers.
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
dependencies: [trl, transformers, datasets, peft, accelerate, torch]
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [Post-Training, TRL, Reinforcement Learning, Fine-Tuning, SFT, DPO, PPO, GRPO, RLHF, Preference Alignment, HuggingFace]
|
||||
|
||||
---
|
||||
|
||||
# TRL - Transformer Reinforcement Learning
|
||||
|
||||
## Quick start
|
||||
|
||||
TRL provides post-training methods for aligning language models with human preferences.
|
||||
|
||||
**Installation**:
|
||||
```bash
|
||||
pip install trl transformers datasets peft accelerate
|
||||
```
|
||||
|
||||
**Supervised Fine-Tuning** (instruction tuning):
|
||||
```python
|
||||
from trl import SFTTrainer
|
||||
|
||||
trainer = SFTTrainer(
|
||||
model="Qwen/Qwen2.5-0.5B",
|
||||
train_dataset=dataset, # Prompt-completion pairs
|
||||
)
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
**DPO** (align with preferences):
|
||||
```python
|
||||
from trl import DPOTrainer, DPOConfig
|
||||
|
||||
config = DPOConfig(output_dir="model-dpo", beta=0.1)
|
||||
trainer = DPOTrainer(
|
||||
model=model,
|
||||
args=config,
|
||||
train_dataset=preference_dataset, # chosen/rejected pairs
|
||||
processing_class=tokenizer
|
||||
)
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
## Common workflows
|
||||
|
||||
### Workflow 1: Full RLHF pipeline (SFT → Reward Model → PPO)
|
||||
|
||||
Complete pipeline from base model to human-aligned model.
|
||||
|
||||
Copy this checklist:
|
||||
|
||||
```
|
||||
RLHF Training:
|
||||
- [ ] Step 1: Supervised fine-tuning (SFT)
|
||||
- [ ] Step 2: Train reward model
|
||||
- [ ] Step 3: PPO reinforcement learning
|
||||
- [ ] Step 4: Evaluate aligned model
|
||||
```
|
||||
|
||||
**Step 1: Supervised fine-tuning**
|
||||
|
||||
Train base model on instruction-following data:
|
||||
|
||||
```python
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from trl import SFTTrainer, SFTConfig
|
||||
from datasets import load_dataset
|
||||
|
||||
# Load model
|
||||
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B")
|
||||
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
|
||||
|
||||
# Load instruction dataset
|
||||
dataset = load_dataset("trl-lib/Capybara", split="train")
|
||||
|
||||
# Configure training
|
||||
training_args = SFTConfig(
|
||||
output_dir="Qwen2.5-0.5B-SFT",
|
||||
per_device_train_batch_size=4,
|
||||
num_train_epochs=1,
|
||||
learning_rate=2e-5,
|
||||
logging_steps=10,
|
||||
save_strategy="epoch"
|
||||
)
|
||||
|
||||
# Train
|
||||
trainer = SFTTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=dataset,
|
||||
tokenizer=tokenizer
|
||||
)
|
||||
trainer.train()
|
||||
trainer.save_model()
|
||||
```
|
||||
|
||||
**Step 2: Train reward model**
|
||||
|
||||
Train model to predict human preferences:
|
||||
|
||||
```python
|
||||
from transformers import AutoModelForSequenceClassification
|
||||
from trl import RewardTrainer, RewardConfig
|
||||
|
||||
# Load SFT model as base
|
||||
model = AutoModelForSequenceClassification.from_pretrained(
|
||||
"Qwen2.5-0.5B-SFT",
|
||||
num_labels=1 # Single reward score
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained("Qwen2.5-0.5B-SFT")
|
||||
|
||||
# Load preference data (chosen/rejected pairs)
|
||||
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
|
||||
|
||||
# Configure training
|
||||
training_args = RewardConfig(
|
||||
output_dir="Qwen2.5-0.5B-Reward",
|
||||
per_device_train_batch_size=2,
|
||||
num_train_epochs=1,
|
||||
learning_rate=1e-5
|
||||
)
|
||||
|
||||
# Train reward model
|
||||
trainer = RewardTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
processing_class=tokenizer,
|
||||
train_dataset=dataset
|
||||
)
|
||||
trainer.train()
|
||||
trainer.save_model()
|
||||
```
|
||||
|
||||
**Step 3: PPO reinforcement learning**
|
||||
|
||||
Optimize policy using reward model:
|
||||
|
||||
```bash
|
||||
python -m trl.scripts.ppo \
|
||||
--model_name_or_path Qwen2.5-0.5B-SFT \
|
||||
--reward_model_path Qwen2.5-0.5B-Reward \
|
||||
--dataset_name trl-internal-testing/descriptiveness-sentiment-trl-style \
|
||||
--output_dir Qwen2.5-0.5B-PPO \
|
||||
--learning_rate 3e-6 \
|
||||
--per_device_train_batch_size 64 \
|
||||
--total_episodes 10000
|
||||
```
|
||||
|
||||
**Step 4: Evaluate**
|
||||
|
||||
```python
|
||||
from transformers import pipeline
|
||||
|
||||
# Load aligned model
|
||||
generator = pipeline("text-generation", model="Qwen2.5-0.5B-PPO")
|
||||
|
||||
# Test
|
||||
prompt = "Explain quantum computing to a 10-year-old"
|
||||
output = generator(prompt, max_length=200)[0]["generated_text"]
|
||||
print(output)
|
||||
```
|
||||
|
||||
### Workflow 2: Simple preference alignment with DPO
|
||||
|
||||
Align model with preferences without reward model.
|
||||
|
||||
Copy this checklist:
|
||||
|
||||
```
|
||||
DPO Training:
|
||||
- [ ] Step 1: Prepare preference dataset
|
||||
- [ ] Step 2: Configure DPO
|
||||
- [ ] Step 3: Train with DPOTrainer
|
||||
- [ ] Step 4: Evaluate alignment
|
||||
```
|
||||
|
||||
**Step 1: Prepare preference dataset**
|
||||
|
||||
Dataset format:
|
||||
```json
|
||||
{
|
||||
"prompt": "What is the capital of France?",
|
||||
"chosen": "The capital of France is Paris.",
|
||||
"rejected": "I don't know."
|
||||
}
|
||||
```
|
||||
|
||||
Load dataset:
|
||||
```python
|
||||
from datasets import load_dataset
|
||||
|
||||
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
|
||||
# Or load your own
|
||||
# dataset = load_dataset("json", data_files="preferences.json")
|
||||
```
|
||||
|
||||
**Step 2: Configure DPO**
|
||||
|
||||
```python
|
||||
from trl import DPOConfig
|
||||
|
||||
config = DPOConfig(
|
||||
output_dir="Qwen2.5-0.5B-DPO",
|
||||
per_device_train_batch_size=4,
|
||||
num_train_epochs=1,
|
||||
learning_rate=5e-7,
|
||||
beta=0.1, # KL penalty strength
|
||||
max_prompt_length=512,
|
||||
max_length=1024,
|
||||
logging_steps=10
|
||||
)
|
||||
```
|
||||
|
||||
**Step 3: Train with DPOTrainer**
|
||||
|
||||
```python
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from trl import DPOTrainer
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
|
||||
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
|
||||
|
||||
trainer = DPOTrainer(
|
||||
model=model,
|
||||
args=config,
|
||||
train_dataset=dataset,
|
||||
processing_class=tokenizer
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
trainer.save_model()
|
||||
```
|
||||
|
||||
**CLI alternative**:
|
||||
```bash
|
||||
trl dpo \
|
||||
--model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \
|
||||
--dataset_name argilla/Capybara-Preferences \
|
||||
--output_dir Qwen2.5-0.5B-DPO \
|
||||
--per_device_train_batch_size 4 \
|
||||
--learning_rate 5e-7 \
|
||||
--beta 0.1
|
||||
```
|
||||
|
||||
### Workflow 3: Memory-efficient online RL with GRPO
|
||||
|
||||
Train with reinforcement learning using minimal memory.
|
||||
|
||||
Copy this checklist:
|
||||
|
||||
```
|
||||
GRPO Training:
|
||||
- [ ] Step 1: Define reward function
|
||||
- [ ] Step 2: Configure GRPO
|
||||
- [ ] Step 3: Train with GRPOTrainer
|
||||
```
|
||||
|
||||
**Step 1: Define reward function**
|
||||
|
||||
```python
|
||||
def reward_function(completions, **kwargs):
|
||||
"""
|
||||
Compute rewards for completions.
|
||||
|
||||
Args:
|
||||
completions: List of generated texts
|
||||
|
||||
Returns:
|
||||
List of reward scores (floats)
|
||||
"""
|
||||
rewards = []
|
||||
for completion in completions:
|
||||
# Example: reward based on length and unique words
|
||||
score = len(completion.split()) # Favor longer responses
|
||||
score += len(set(completion.lower().split())) # Reward unique words
|
||||
rewards.append(score)
|
||||
return rewards
|
||||
```
|
||||
|
||||
Or use a reward model:
|
||||
```python
|
||||
from transformers import pipeline
|
||||
|
||||
reward_model = pipeline("text-classification", model="reward-model-path")
|
||||
|
||||
def reward_from_model(completions, prompts, **kwargs):
|
||||
# Combine prompt + completion
|
||||
full_texts = [p + c for p, c in zip(prompts, completions)]
|
||||
# Get reward scores
|
||||
results = reward_model(full_texts)
|
||||
return [r["score"] for r in results]
|
||||
```
|
||||
|
||||
**Step 2: Configure GRPO**
|
||||
|
||||
```python
|
||||
from trl import GRPOConfig
|
||||
|
||||
config = GRPOConfig(
|
||||
output_dir="Qwen2-GRPO",
|
||||
per_device_train_batch_size=4,
|
||||
num_train_epochs=1,
|
||||
learning_rate=1e-5,
|
||||
num_generations=4, # Generate 4 completions per prompt
|
||||
max_new_tokens=128
|
||||
)
|
||||
```
|
||||
|
||||
**Step 3: Train with GRPOTrainer**
|
||||
|
||||
```python
|
||||
from datasets import load_dataset
|
||||
from trl import GRPOTrainer
|
||||
|
||||
# Load prompt-only dataset
|
||||
dataset = load_dataset("trl-lib/tldr", split="train")
|
||||
|
||||
trainer = GRPOTrainer(
|
||||
model="Qwen/Qwen2-0.5B-Instruct",
|
||||
reward_funcs=reward_function, # Your reward function
|
||||
args=config,
|
||||
train_dataset=dataset
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
**CLI**:
|
||||
```bash
|
||||
trl grpo \
|
||||
--model_name_or_path Qwen/Qwen2-0.5B-Instruct \
|
||||
--dataset_name trl-lib/tldr \
|
||||
--output_dir Qwen2-GRPO \
|
||||
--num_generations 4
|
||||
```
|
||||
|
||||
## When to use vs alternatives
|
||||
|
||||
**Use TRL when:**
|
||||
- Need to align model with human preferences
|
||||
- Have preference data (chosen/rejected pairs)
|
||||
- Want to use reinforcement learning (PPO, GRPO)
|
||||
- Need reward model training
|
||||
- Doing RLHF (full pipeline)
|
||||
|
||||
**Method selection**:
|
||||
- **SFT**: Have prompt-completion pairs, want basic instruction following
|
||||
- **DPO**: Have preferences, want simple alignment (no reward model needed)
|
||||
- **PPO**: Have reward model, need maximum control over RL
|
||||
- **GRPO**: Memory-constrained, want online RL
|
||||
- **Reward Model**: Building RLHF pipeline, need to score generations
|
||||
|
||||
**Use alternatives instead:**
|
||||
- **HuggingFace Trainer**: Basic fine-tuning without RL
|
||||
- **Axolotl**: YAML-based training configuration
|
||||
- **LitGPT**: Educational, minimal fine-tuning
|
||||
- **Unsloth**: Fast LoRA training
|
||||
|
||||
## Common issues
|
||||
|
||||
**Issue: OOM during DPO training**
|
||||
|
||||
Reduce batch size and sequence length:
|
||||
```python
|
||||
config = DPOConfig(
|
||||
per_device_train_batch_size=1, # Reduce from 4
|
||||
max_length=512, # Reduce from 1024
|
||||
gradient_accumulation_steps=8 # Maintain effective batch
|
||||
)
|
||||
```
|
||||
|
||||
Or use gradient checkpointing:
|
||||
```python
|
||||
model.gradient_checkpointing_enable()
|
||||
```
|
||||
|
||||
**Issue: Poor alignment quality**
|
||||
|
||||
Tune beta parameter:
|
||||
```python
|
||||
# Higher beta = more conservative (stays closer to reference)
|
||||
config = DPOConfig(beta=0.5) # Default 0.1
|
||||
|
||||
# Lower beta = more aggressive alignment
|
||||
config = DPOConfig(beta=0.01)
|
||||
```
|
||||
|
||||
**Issue: Reward model not learning**
|
||||
|
||||
Check loss type and learning rate:
|
||||
```python
|
||||
config = RewardConfig(
|
||||
learning_rate=1e-5, # Try different LR
|
||||
num_train_epochs=3 # Train longer
|
||||
)
|
||||
```
|
||||
|
||||
Ensure preference dataset has clear winners:
|
||||
```python
|
||||
# Verify dataset
|
||||
print(dataset[0])
|
||||
# Should have clear chosen > rejected
|
||||
```
|
||||
|
||||
**Issue: PPO training unstable**
|
||||
|
||||
Adjust KL coefficient:
|
||||
```python
|
||||
config = PPOConfig(
|
||||
kl_coef=0.1, # Increase from 0.05
|
||||
cliprange=0.1 # Reduce from 0.2
|
||||
)
|
||||
```
|
||||
|
||||
## Advanced topics
|
||||
|
||||
**SFT training guide**: See [references/sft-training.md](references/sft-training.md) for dataset formats, chat templates, packing strategies, and multi-GPU training.
|
||||
|
||||
**DPO variants**: See [references/dpo-variants.md](references/dpo-variants.md) for IPO, cDPO, RPO, and other DPO loss functions with recommended hyperparameters.
|
||||
|
||||
**Reward modeling**: See [references/reward-modeling.md](references/reward-modeling.md) for outcome vs process rewards, Bradley-Terry loss, and reward model evaluation.
|
||||
|
||||
**Online RL methods**: See [references/online-rl.md](references/online-rl.md) for PPO, GRPO, RLOO, and OnlineDPO with detailed configurations.
|
||||
|
||||
## Hardware requirements
|
||||
|
||||
- **GPU**: NVIDIA (CUDA required)
|
||||
- **VRAM**: Depends on model and method
|
||||
- SFT 7B: 16GB (with LoRA)
|
||||
- DPO 7B: 24GB (stores reference model)
|
||||
- PPO 7B: 40GB (policy + reward model)
|
||||
- GRPO 7B: 24GB (more memory efficient)
|
||||
- **Multi-GPU**: Supported via `accelerate`
|
||||
- **Mixed precision**: BF16 recommended (A100/H100)
|
||||
|
||||
**Memory optimization**:
|
||||
- Use LoRA/QLoRA for all methods
|
||||
- Enable gradient checkpointing
|
||||
- Use smaller batch sizes with gradient accumulation
|
||||
|
||||
## Resources
|
||||
|
||||
- Docs: https://huggingface.co/docs/trl/
|
||||
- GitHub: https://github.com/huggingface/trl
|
||||
- Papers:
|
||||
- "Training language models to follow instructions with human feedback" (InstructGPT, 2022)
|
||||
- "Direct Preference Optimization: Your Language Model is Secretly a Reward Model" (DPO, 2023)
|
||||
- "Group Relative Policy Optimization" (GRPO, 2024)
|
||||
- Examples: https://github.com/huggingface/trl/tree/main/examples/scripts
|
||||
|
||||
|
||||
|
||||
|
|
@ -1,227 +0,0 @@
|
|||
# DPO Variants
|
||||
|
||||
Complete guide to Direct Preference Optimization loss variants in TRL.
|
||||
|
||||
## Overview
|
||||
|
||||
DPO optimizes models using preference data (chosen/rejected pairs). TRL supports 10+ loss variants for different scenarios.
|
||||
|
||||
## Loss Types
|
||||
|
||||
### 1. Sigmoid (Standard DPO)
|
||||
|
||||
**Formula**: `-log(sigmoid(β * logits))`
|
||||
|
||||
**When to use**: Default choice, general preference alignment
|
||||
|
||||
**Config**:
|
||||
```python
|
||||
DPOConfig(
|
||||
loss_type="sigmoid",
|
||||
beta=0.1, # KL penalty
|
||||
per_device_train_batch_size=64,
|
||||
learning_rate=1e-6
|
||||
)
|
||||
```
|
||||
|
||||
### 2. IPO (Identity Policy Optimization)
|
||||
|
||||
**Formula**: `(logits - 1/(2β))²`
|
||||
|
||||
**When to use**: Better theoretical foundation, reduce overfitting
|
||||
|
||||
**Config**:
|
||||
```python
|
||||
DPOConfig(
|
||||
loss_type="ipo",
|
||||
beta=0.1,
|
||||
per_device_train_batch_size=90,
|
||||
learning_rate=1e-2
|
||||
)
|
||||
```
|
||||
|
||||
### 3. Hinge (SLiC)
|
||||
|
||||
**Formula**: `ReLU(1 - β * logits)`
|
||||
|
||||
**When to use**: Margin-based objective
|
||||
|
||||
**Config**:
|
||||
```python
|
||||
DPOConfig(
|
||||
loss_type="hinge",
|
||||
beta=0.1,
|
||||
per_device_train_batch_size=512,
|
||||
learning_rate=1e-4
|
||||
)
|
||||
```
|
||||
|
||||
### 4. Robust DPO
|
||||
|
||||
**Formula**: Sigmoid with label smoothing for noise robustness
|
||||
|
||||
**When to use**: Noisy preference labels
|
||||
|
||||
**Config**:
|
||||
```python
|
||||
DPOConfig(
|
||||
loss_type="robust",
|
||||
beta=0.01,
|
||||
label_smoothing=0.1, # Noise probability
|
||||
per_device_train_batch_size=16,
|
||||
learning_rate=1e-3,
|
||||
max_prompt_length=128,
|
||||
max_length=512
|
||||
)
|
||||
```
|
||||
|
||||
### 5. BCO Pair (Binary Classification)
|
||||
|
||||
**Formula**: Train binary classifier (chosen=1, rejected=0)
|
||||
|
||||
**When to use**: Pairwise preference data
|
||||
|
||||
**Config**:
|
||||
```python
|
||||
DPOConfig(
|
||||
loss_type="bco_pair",
|
||||
beta=0.01,
|
||||
per_device_train_batch_size=128,
|
||||
learning_rate=5e-7,
|
||||
max_prompt_length=1536,
|
||||
max_completion_length=512
|
||||
)
|
||||
```
|
||||
|
||||
### 6. SPPO Hard
|
||||
|
||||
**Formula**: Push chosen→0.5, rejected→-0.5
|
||||
|
||||
**When to use**: Nash equilibrium, sparse data
|
||||
|
||||
**Config**:
|
||||
```python
|
||||
DPOConfig(
|
||||
loss_type="sppo_hard",
|
||||
beta=0.1
|
||||
)
|
||||
```
|
||||
|
||||
### 7. DiscoPOP
|
||||
|
||||
**Formula**: Log-Ratio Modulated Loss
|
||||
|
||||
**When to use**: Automated loss discovery
|
||||
|
||||
**Config**:
|
||||
```python
|
||||
DPOConfig(
|
||||
loss_type="discopop",
|
||||
beta=0.05,
|
||||
discopop_tau=0.05,
|
||||
per_device_train_batch_size=64,
|
||||
learning_rate=5e-7
|
||||
)
|
||||
```
|
||||
|
||||
### 8. APO Zero
|
||||
|
||||
**Formula**: Increase chosen, decrease rejected likelihood
|
||||
|
||||
**When to use**: Model worse than winning outputs
|
||||
|
||||
**Config**:
|
||||
```python
|
||||
DPOConfig(
|
||||
loss_type="apo_zero",
|
||||
beta=0.1,
|
||||
per_device_train_batch_size=64,
|
||||
learning_rate=2e-7,
|
||||
max_prompt_length=512,
|
||||
max_completion_length=512
|
||||
)
|
||||
```
|
||||
|
||||
### 9. APO Down
|
||||
|
||||
**Formula**: Decrease both, emphasize rejected reduction
|
||||
|
||||
**When to use**: Model better than winning outputs
|
||||
|
||||
**Config**:
|
||||
```python
|
||||
DPOConfig(
|
||||
loss_type="apo_down",
|
||||
beta=0.1,
|
||||
# Same hyperparameters as apo_zero
|
||||
)
|
||||
```
|
||||
|
||||
### 10. AOT & AOT Pair
|
||||
|
||||
**Formula**: Distributional alignment via stochastic dominance
|
||||
|
||||
**When to use**:
|
||||
- `aot_pair`: Paired preference data
|
||||
- `aot`: Unpaired data
|
||||
|
||||
**Config**:
|
||||
```python
|
||||
DPOConfig(
|
||||
loss_type="aot_pair", # or "aot"
|
||||
beta=0.1,
|
||||
label_smoothing=0.0
|
||||
)
|
||||
```
|
||||
|
||||
## Multi-Loss Training
|
||||
|
||||
Combine multiple losses:
|
||||
|
||||
```python
|
||||
DPOConfig(
|
||||
loss_type=["sigmoid", "ipo"],
|
||||
loss_weights=[0.7, 0.3], # Weighted combination
|
||||
beta=0.1
|
||||
)
|
||||
```
|
||||
|
||||
## Key Parameters
|
||||
|
||||
### Beta (β)
|
||||
|
||||
Controls deviation from reference model:
|
||||
- **Higher** (0.5): More conservative, stays close to reference
|
||||
- **Lower** (0.01): More aggressive alignment
|
||||
- **Default**: 0.1
|
||||
|
||||
### Label Smoothing
|
||||
|
||||
For robust DPO:
|
||||
- **0.0**: No smoothing (default)
|
||||
- **0.1-0.3**: Moderate noise robustness
|
||||
- **0.5**: Maximum noise tolerance
|
||||
|
||||
### Max Lengths
|
||||
|
||||
- `max_prompt_length`: 128-1536
|
||||
- `max_completion_length`: 128-512
|
||||
- `max_length`: Total sequence (1024-2048)
|
||||
|
||||
## Comparison Table
|
||||
|
||||
| Loss | Speed | Stability | Best For |
|
||||
|------|-------|-----------|----------|
|
||||
| Sigmoid | Fast | Good | **General use** |
|
||||
| IPO | Fast | Better | Overfitting issues |
|
||||
| Hinge | Fast | Good | Margin objectives |
|
||||
| Robust | Fast | Best | Noisy data |
|
||||
| BCO | Medium | Good | Binary classification |
|
||||
| DiscoPOP | Fast | Good | New architectures |
|
||||
| APO | Fast | Good | Model quality matching |
|
||||
|
||||
## References
|
||||
|
||||
- DPO paper: https://arxiv.org/abs/2305.18290
|
||||
- IPO paper: https://arxiv.org/abs/2310.12036
|
||||
- TRL docs: https://huggingface.co/docs/trl/dpo_trainer
|
||||
|
|
@ -1,82 +0,0 @@
|
|||
# Online RL Methods
|
||||
|
||||
Guide to online reinforcement learning with PPO, GRPO, RLOO, and OnlineDPO.
|
||||
|
||||
## Overview
|
||||
|
||||
Online RL generates completions during training and optimizes based on rewards.
|
||||
|
||||
## PPO (Proximal Policy Optimization)
|
||||
|
||||
Classic RL algorithm for LLM alignment.
|
||||
|
||||
### Basic Usage
|
||||
|
||||
```bash
|
||||
python -m trl.scripts.ppo \
|
||||
--model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \
|
||||
--reward_model_path reward-model \
|
||||
--dataset_name trl-internal-testing/descriptiveness-sentiment-trl-style \
|
||||
--output_dir model-ppo \
|
||||
--learning_rate 3e-6 \
|
||||
--per_device_train_batch_size 64 \
|
||||
--total_episodes 10000 \
|
||||
--num_ppo_epochs 4 \
|
||||
--kl_coef 0.05
|
||||
```
|
||||
|
||||
### Key Parameters
|
||||
|
||||
- `kl_coef`: KL penalty (0.05-0.2)
|
||||
- `num_ppo_epochs`: Epochs per batch (2-4)
|
||||
- `cliprange`: PPO clip (0.1-0.3)
|
||||
- `vf_coef`: Value function coef (0.1)
|
||||
|
||||
## GRPO (Group Relative Policy Optimization)
|
||||
|
||||
Memory-efficient online RL.
|
||||
|
||||
### Basic Usage
|
||||
|
||||
```python
|
||||
from trl import GRPOTrainer, GRPOConfig
|
||||
from datasets import load_dataset
|
||||
|
||||
# Define reward function
|
||||
def reward_func(completions, **kwargs):
|
||||
return [len(set(c.split())) for c in completions]
|
||||
|
||||
config = GRPOConfig(
|
||||
output_dir="model-grpo",
|
||||
num_generations=4, # Completions per prompt
|
||||
max_new_tokens=128
|
||||
)
|
||||
|
||||
trainer = GRPOTrainer(
|
||||
model="Qwen/Qwen2-0.5B-Instruct",
|
||||
reward_funcs=reward_func,
|
||||
args=config,
|
||||
train_dataset=load_dataset("trl-lib/tldr", split="train")
|
||||
)
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
### Key Parameters
|
||||
|
||||
- `num_generations`: 2-8 completions
|
||||
- `max_new_tokens`: 64-256
|
||||
- Learning rate: 1e-5 to 1e-4
|
||||
|
||||
## Memory Comparison
|
||||
|
||||
| Method | Memory (7B) | Speed | Use Case |
|
||||
|--------|-------------|-------|----------|
|
||||
| PPO | 40GB | Medium | Maximum control |
|
||||
| GRPO | 24GB | Fast | **Memory-constrained** |
|
||||
| OnlineDPO | 28GB | Fast | No reward model |
|
||||
|
||||
## References
|
||||
|
||||
- PPO paper: https://arxiv.org/abs/1707.06347
|
||||
- GRPO paper: https://arxiv.org/abs/2402.03300
|
||||
- TRL docs: https://huggingface.co/docs/trl/
|
||||
|
|
@ -1,122 +0,0 @@
|
|||
# Reward Modeling
|
||||
|
||||
Guide to training reward models with TRL for RLHF pipelines.
|
||||
|
||||
## Overview
|
||||
|
||||
Reward models score completions based on human preferences. Used in:
|
||||
- PPO training (RL feedback)
|
||||
- GRPO online RL
|
||||
- Completion ranking
|
||||
|
||||
## Basic Training
|
||||
|
||||
```python
|
||||
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
||||
from trl import RewardTrainer, RewardConfig
|
||||
from datasets import load_dataset
|
||||
|
||||
# Load model (num_labels=1 for single reward score)
|
||||
model = AutoModelForSequenceClassification.from_pretrained(
|
||||
"Qwen/Qwen2.5-0.5B-Instruct",
|
||||
num_labels=1
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
|
||||
|
||||
# Load preference dataset (chosen/rejected pairs)
|
||||
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
|
||||
|
||||
# Configure
|
||||
config = RewardConfig(
|
||||
output_dir="Qwen2.5-Reward",
|
||||
per_device_train_batch_size=2,
|
||||
num_train_epochs=1,
|
||||
learning_rate=1e-5
|
||||
)
|
||||
|
||||
# Train
|
||||
trainer = RewardTrainer(
|
||||
model=model,
|
||||
args=config,
|
||||
processing_class=tokenizer,
|
||||
train_dataset=dataset
|
||||
)
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
## Dataset Format
|
||||
|
||||
Required fields:
|
||||
```json
|
||||
{
|
||||
"prompt": "Question or instruction",
|
||||
"chosen": "Better response",
|
||||
"rejected": "Worse response"
|
||||
}
|
||||
```
|
||||
|
||||
## Bradley-Terry Loss
|
||||
|
||||
Default loss function:
|
||||
```
|
||||
loss = -log(sigmoid(reward_chosen - reward_rejected))
|
||||
```
|
||||
|
||||
Learns to score chosen > rejected.
|
||||
|
||||
## Using Reward Models
|
||||
|
||||
### Inference
|
||||
|
||||
```python
|
||||
from transformers import pipeline
|
||||
|
||||
# Load trained reward model
|
||||
reward_pipe = pipeline("text-classification", model="Qwen2.5-Reward")
|
||||
|
||||
# Score completions
|
||||
texts = ["Good answer", "Bad answer"]
|
||||
scores = reward_pipe(texts)
|
||||
print(scores) # Higher score = better
|
||||
```
|
||||
|
||||
### In PPO
|
||||
|
||||
```python
|
||||
from trl import PPOTrainer, PPOConfig
|
||||
|
||||
config = PPOConfig(
|
||||
reward_model_path="Qwen2.5-Reward" # Use trained reward model
|
||||
)
|
||||
|
||||
trainer = PPOTrainer(
|
||||
model=policy_model,
|
||||
config=config,
|
||||
# Reward model loaded automatically
|
||||
)
|
||||
```
|
||||
|
||||
## Hyperparameters
|
||||
|
||||
| Model Size | Learning Rate | Batch Size | Epochs |
|
||||
|------------|---------------|------------|--------|
|
||||
| <1B | 2e-5 | 4-8 | 1-2 |
|
||||
| 1-7B | 1e-5 | 2-4 | 1 |
|
||||
| 7-13B | 5e-6 | 1-2 | 1 |
|
||||
|
||||
## Evaluation
|
||||
|
||||
Check reward separation:
|
||||
```python
|
||||
# Chosen should score higher than rejected
|
||||
chosen_rewards = model(**chosen_inputs).logits
|
||||
rejected_rewards = model(**rejected_inputs).logits
|
||||
|
||||
accuracy = (chosen_rewards > rejected_rewards).float().mean()
|
||||
print(f"Accuracy: {accuracy:.2%}") # Target: >80%
|
||||
```
|
||||
|
||||
## References
|
||||
|
||||
- InstructGPT paper: https://arxiv.org/abs/2203.02155
|
||||
- TRL docs: https://huggingface.co/docs/trl/reward_trainer
|
||||
|
|
@ -1,168 +0,0 @@
|
|||
# SFT Training Guide
|
||||
|
||||
Complete guide to Supervised Fine-Tuning (SFT) with TRL for instruction tuning and task-specific fine-tuning.
|
||||
|
||||
## Overview
|
||||
|
||||
SFT trains models on input-output pairs to minimize cross-entropy loss. Use for:
|
||||
- Instruction following
|
||||
- Task-specific fine-tuning
|
||||
- Chatbot training
|
||||
- Domain adaptation
|
||||
|
||||
## Dataset Formats
|
||||
|
||||
### Format 1: Prompt-Completion
|
||||
|
||||
```json
|
||||
[
|
||||
{
|
||||
"prompt": "What is the capital of France?",
|
||||
"completion": "The capital of France is Paris."
|
||||
}
|
||||
]
|
||||
```
|
||||
|
||||
### Format 2: Conversational (ChatML)
|
||||
|
||||
```json
|
||||
[
|
||||
{
|
||||
"messages": [
|
||||
{"role": "user", "content": "What is Python?"},
|
||||
{"role": "assistant", "content": "Python is a programming language."}
|
||||
]
|
||||
}
|
||||
]
|
||||
```
|
||||
|
||||
### Format 3: Text-only
|
||||
|
||||
```json
|
||||
[
|
||||
{"text": "User: Hello\nAssistant: Hi! How can I help?"}
|
||||
]
|
||||
```
|
||||
|
||||
## Basic Training
|
||||
|
||||
```python
|
||||
from trl import SFTTrainer, SFTConfig
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from datasets import load_dataset
|
||||
|
||||
# Load model
|
||||
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B")
|
||||
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
|
||||
|
||||
# Load dataset
|
||||
dataset = load_dataset("trl-lib/Capybara", split="train")
|
||||
|
||||
# Configure
|
||||
config = SFTConfig(
|
||||
output_dir="Qwen2.5-SFT",
|
||||
per_device_train_batch_size=4,
|
||||
num_train_epochs=1,
|
||||
learning_rate=2e-5,
|
||||
save_strategy="epoch"
|
||||
)
|
||||
|
||||
# Train
|
||||
trainer = SFTTrainer(
|
||||
model=model,
|
||||
args=config,
|
||||
train_dataset=dataset,
|
||||
tokenizer=tokenizer
|
||||
)
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
## Chat Templates
|
||||
|
||||
Apply chat templates automatically:
|
||||
|
||||
```python
|
||||
trainer = SFTTrainer(
|
||||
model=model,
|
||||
args=config,
|
||||
train_dataset=dataset, # Messages format
|
||||
tokenizer=tokenizer
|
||||
# Chat template applied automatically
|
||||
)
|
||||
```
|
||||
|
||||
Or manually:
|
||||
```python
|
||||
def format_chat(example):
|
||||
messages = example["messages"]
|
||||
text = tokenizer.apply_chat_template(messages, tokenize=False)
|
||||
return {"text": text}
|
||||
|
||||
dataset = dataset.map(format_chat)
|
||||
```
|
||||
|
||||
## Packing for Efficiency
|
||||
|
||||
Pack multiple sequences into one to maximize GPU utilization:
|
||||
|
||||
```python
|
||||
config = SFTConfig(
|
||||
packing=True, # Enable packing
|
||||
max_seq_length=2048,
|
||||
dataset_text_field="text"
|
||||
)
|
||||
```
|
||||
|
||||
**Benefits**: 2-3× faster training
|
||||
**Trade-off**: Slightly more complex batching
|
||||
|
||||
## Multi-GPU Training
|
||||
|
||||
```bash
|
||||
accelerate launch --num_processes 4 train_sft.py
|
||||
```
|
||||
|
||||
Or with config:
|
||||
```python
|
||||
config = SFTConfig(
|
||||
output_dir="model-sft",
|
||||
per_device_train_batch_size=4,
|
||||
gradient_accumulation_steps=4,
|
||||
num_train_epochs=1
|
||||
)
|
||||
```
|
||||
|
||||
## LoRA Fine-Tuning
|
||||
|
||||
```python
|
||||
from peft import LoraConfig
|
||||
|
||||
lora_config = LoraConfig(
|
||||
r=16,
|
||||
lora_alpha=32,
|
||||
target_modules="all-linear",
|
||||
lora_dropout=0.05,
|
||||
task_type="CAUSAL_LM"
|
||||
)
|
||||
|
||||
trainer = SFTTrainer(
|
||||
model=model,
|
||||
args=config,
|
||||
train_dataset=dataset,
|
||||
peft_config=lora_config # Add LoRA
|
||||
)
|
||||
```
|
||||
|
||||
## Hyperparameters
|
||||
|
||||
| Model Size | Learning Rate | Batch Size | Epochs |
|
||||
|------------|---------------|------------|--------|
|
||||
| <1B | 5e-5 | 8-16 | 1-3 |
|
||||
| 1-7B | 2e-5 | 4-8 | 1-2 |
|
||||
| 7-13B | 1e-5 | 2-4 | 1 |
|
||||
| 13B+ | 5e-6 | 1-2 | 1 |
|
||||
|
||||
## References
|
||||
|
||||
- TRL docs: https://huggingface.co/docs/trl/sft_trainer
|
||||
- Examples: https://github.com/huggingface/trl/tree/main/examples/scripts
|
||||
|
|
@ -1,320 +0,0 @@
|
|||
---
|
||||
name: whisper
|
||||
description: OpenAI's general-purpose speech recognition model. Supports 99 languages, transcription, translation to English, and language identification. Six model sizes from tiny (39M params) to large (1550M params). Use for speech-to-text, podcast transcription, or multilingual audio processing. Best for robust, multilingual ASR.
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
dependencies: [openai-whisper, transformers, torch]
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [Whisper, Speech Recognition, ASR, Multimodal, Multilingual, OpenAI, Speech-To-Text, Transcription, Translation, Audio Processing]
|
||||
|
||||
---
|
||||
|
||||
# Whisper - Robust Speech Recognition
|
||||
|
||||
OpenAI's multilingual speech recognition model.
|
||||
|
||||
## When to use Whisper
|
||||
|
||||
**Use when:**
|
||||
- Speech-to-text transcription (99 languages)
|
||||
- Podcast/video transcription
|
||||
- Meeting notes automation
|
||||
- Translation to English
|
||||
- Noisy audio transcription
|
||||
- Multilingual audio processing
|
||||
|
||||
**Metrics**:
|
||||
- **72,900+ GitHub stars**
|
||||
- 99 languages supported
|
||||
- Trained on 680,000 hours of audio
|
||||
- MIT License
|
||||
|
||||
**Use alternatives instead**:
|
||||
- **AssemblyAI**: Managed API, speaker diarization
|
||||
- **Deepgram**: Real-time streaming ASR
|
||||
- **Google Speech-to-Text**: Cloud-based
|
||||
|
||||
## Quick start
|
||||
|
||||
### Installation
|
||||
|
||||
```bash
|
||||
# Requires Python 3.8-3.11
|
||||
pip install -U openai-whisper
|
||||
|
||||
# Requires ffmpeg
|
||||
# macOS: brew install ffmpeg
|
||||
# Ubuntu: sudo apt install ffmpeg
|
||||
# Windows: choco install ffmpeg
|
||||
```
|
||||
|
||||
### Basic transcription
|
||||
|
||||
```python
|
||||
import whisper
|
||||
|
||||
# Load model
|
||||
model = whisper.load_model("base")
|
||||
|
||||
# Transcribe
|
||||
result = model.transcribe("audio.mp3")
|
||||
|
||||
# Print text
|
||||
print(result["text"])
|
||||
|
||||
# Access segments
|
||||
for segment in result["segments"]:
|
||||
print(f"[{segment['start']:.2f}s - {segment['end']:.2f}s] {segment['text']}")
|
||||
```
|
||||
|
||||
## Model sizes
|
||||
|
||||
```python
|
||||
# Available models
|
||||
models = ["tiny", "base", "small", "medium", "large", "turbo"]
|
||||
|
||||
# Load specific model
|
||||
model = whisper.load_model("turbo") # Fastest, good quality
|
||||
```
|
||||
|
||||
| Model | Parameters | English-only | Multilingual | Speed | VRAM |
|
||||
|-------|------------|--------------|--------------|-------|------|
|
||||
| tiny | 39M | ✓ | ✓ | ~32x | ~1 GB |
|
||||
| base | 74M | ✓ | ✓ | ~16x | ~1 GB |
|
||||
| small | 244M | ✓ | ✓ | ~6x | ~2 GB |
|
||||
| medium | 769M | ✓ | ✓ | ~2x | ~5 GB |
|
||||
| large | 1550M | ✗ | ✓ | 1x | ~10 GB |
|
||||
| turbo | 809M | ✗ | ✓ | ~8x | ~6 GB |
|
||||
|
||||
**Recommendation**: Use `turbo` for best speed/quality, `base` for prototyping
|
||||
|
||||
## Transcription options
|
||||
|
||||
### Language specification
|
||||
|
||||
```python
|
||||
# Auto-detect language
|
||||
result = model.transcribe("audio.mp3")
|
||||
|
||||
# Specify language (faster)
|
||||
result = model.transcribe("audio.mp3", language="en")
|
||||
|
||||
# Supported: en, es, fr, de, it, pt, ru, ja, ko, zh, and 89 more
|
||||
```
|
||||
|
||||
### Task selection
|
||||
|
||||
```python
|
||||
# Transcription (default)
|
||||
result = model.transcribe("audio.mp3", task="transcribe")
|
||||
|
||||
# Translation to English
|
||||
result = model.transcribe("spanish.mp3", task="translate")
|
||||
# Input: Spanish audio → Output: English text
|
||||
```
|
||||
|
||||
### Initial prompt
|
||||
|
||||
```python
|
||||
# Improve accuracy with context
|
||||
result = model.transcribe(
|
||||
"audio.mp3",
|
||||
initial_prompt="This is a technical podcast about machine learning and AI."
|
||||
)
|
||||
|
||||
# Helps with:
|
||||
# - Technical terms
|
||||
# - Proper nouns
|
||||
# - Domain-specific vocabulary
|
||||
```
|
||||
|
||||
### Timestamps
|
||||
|
||||
```python
|
||||
# Word-level timestamps
|
||||
result = model.transcribe("audio.mp3", word_timestamps=True)
|
||||
|
||||
for segment in result["segments"]:
|
||||
for word in segment["words"]:
|
||||
print(f"{word['word']} ({word['start']:.2f}s - {word['end']:.2f}s)")
|
||||
```
|
||||
|
||||
### Temperature fallback
|
||||
|
||||
```python
|
||||
# Retry with different temperatures if confidence low
|
||||
result = model.transcribe(
|
||||
"audio.mp3",
|
||||
temperature=(0.0, 0.2, 0.4, 0.6, 0.8, 1.0)
|
||||
)
|
||||
```
|
||||
|
||||
## Command line usage
|
||||
|
||||
```bash
|
||||
# Basic transcription
|
||||
whisper audio.mp3
|
||||
|
||||
# Specify model
|
||||
whisper audio.mp3 --model turbo
|
||||
|
||||
# Output formats
|
||||
whisper audio.mp3 --output_format txt # Plain text
|
||||
whisper audio.mp3 --output_format srt # Subtitles
|
||||
whisper audio.mp3 --output_format vtt # WebVTT
|
||||
whisper audio.mp3 --output_format json # JSON with timestamps
|
||||
|
||||
# Language
|
||||
whisper audio.mp3 --language Spanish
|
||||
|
||||
# Translation
|
||||
whisper spanish.mp3 --task translate
|
||||
```
|
||||
|
||||
## Batch processing
|
||||
|
||||
```python
|
||||
import os
|
||||
|
||||
audio_files = ["file1.mp3", "file2.mp3", "file3.mp3"]
|
||||
|
||||
for audio_file in audio_files:
|
||||
print(f"Transcribing {audio_file}...")
|
||||
result = model.transcribe(audio_file)
|
||||
|
||||
# Save to file
|
||||
output_file = audio_file.replace(".mp3", ".txt")
|
||||
with open(output_file, "w") as f:
|
||||
f.write(result["text"])
|
||||
```
|
||||
|
||||
## Real-time transcription
|
||||
|
||||
```python
|
||||
# For streaming audio, use faster-whisper
|
||||
# pip install faster-whisper
|
||||
|
||||
from faster_whisper import WhisperModel
|
||||
|
||||
model = WhisperModel("base", device="cuda", compute_type="float16")
|
||||
|
||||
# Transcribe with streaming
|
||||
segments, info = model.transcribe("audio.mp3", beam_size=5)
|
||||
|
||||
for segment in segments:
|
||||
print(f"[{segment.start:.2f}s -> {segment.end:.2f}s] {segment.text}")
|
||||
```
|
||||
|
||||
## GPU acceleration
|
||||
|
||||
```python
|
||||
import whisper
|
||||
|
||||
# Automatically uses GPU if available
|
||||
model = whisper.load_model("turbo")
|
||||
|
||||
# Force CPU
|
||||
model = whisper.load_model("turbo", device="cpu")
|
||||
|
||||
# Force GPU
|
||||
model = whisper.load_model("turbo", device="cuda")
|
||||
|
||||
# 10-20× faster on GPU
|
||||
```
|
||||
|
||||
## Integration with other tools
|
||||
|
||||
### Subtitle generation
|
||||
|
||||
```bash
|
||||
# Generate SRT subtitles
|
||||
whisper video.mp4 --output_format srt --language English
|
||||
|
||||
# Output: video.srt
|
||||
```
|
||||
|
||||
### With LangChain
|
||||
|
||||
```python
|
||||
from langchain.document_loaders import WhisperTranscriptionLoader
|
||||
|
||||
loader = WhisperTranscriptionLoader(file_path="audio.mp3")
|
||||
docs = loader.load()
|
||||
|
||||
# Use transcription in RAG
|
||||
from langchain_chroma import Chroma
|
||||
from langchain_openai import OpenAIEmbeddings
|
||||
|
||||
vectorstore = Chroma.from_documents(docs, OpenAIEmbeddings())
|
||||
```
|
||||
|
||||
### Extract audio from video
|
||||
|
||||
```bash
|
||||
# Use ffmpeg to extract audio
|
||||
ffmpeg -i video.mp4 -vn -acodec pcm_s16le audio.wav
|
||||
|
||||
# Then transcribe
|
||||
whisper audio.wav
|
||||
```
|
||||
|
||||
## Best practices
|
||||
|
||||
1. **Use turbo model** - Best speed/quality for English
|
||||
2. **Specify language** - Faster than auto-detect
|
||||
3. **Add initial prompt** - Improves technical terms
|
||||
4. **Use GPU** - 10-20× faster
|
||||
5. **Batch process** - More efficient
|
||||
6. **Convert to WAV** - Better compatibility
|
||||
7. **Split long audio** - <30 min chunks
|
||||
8. **Check language support** - Quality varies by language
|
||||
9. **Use faster-whisper** - 4× faster than openai-whisper
|
||||
10. **Monitor VRAM** - Scale model size to hardware
|
||||
|
||||
## Performance
|
||||
|
||||
| Model | Real-time factor (CPU) | Real-time factor (GPU) |
|
||||
|-------|------------------------|------------------------|
|
||||
| tiny | ~0.32 | ~0.01 |
|
||||
| base | ~0.16 | ~0.01 |
|
||||
| turbo | ~0.08 | ~0.01 |
|
||||
| large | ~1.0 | ~0.05 |
|
||||
|
||||
*Real-time factor: 0.1 = 10× faster than real-time*
|
||||
|
||||
## Language support
|
||||
|
||||
Top-supported languages:
|
||||
- English (en)
|
||||
- Spanish (es)
|
||||
- French (fr)
|
||||
- German (de)
|
||||
- Italian (it)
|
||||
- Portuguese (pt)
|
||||
- Russian (ru)
|
||||
- Japanese (ja)
|
||||
- Korean (ko)
|
||||
- Chinese (zh)
|
||||
|
||||
Full list: 99 languages total
|
||||
|
||||
## Limitations
|
||||
|
||||
1. **Hallucinations** - May repeat or invent text
|
||||
2. **Long-form accuracy** - Degrades on >30 min audio
|
||||
3. **Speaker identification** - No diarization
|
||||
4. **Accents** - Quality varies
|
||||
5. **Background noise** - Can affect accuracy
|
||||
6. **Real-time latency** - Not suitable for live captioning
|
||||
|
||||
## Resources
|
||||
|
||||
- **GitHub**: https://github.com/openai/whisper ⭐ 72,900+
|
||||
- **Paper**: https://arxiv.org/abs/2212.04356
|
||||
- **Model Card**: https://github.com/openai/whisper/blob/main/model-card.md
|
||||
- **Colab**: Available in repo
|
||||
- **License**: MIT
|
||||
|
||||
|
||||
|
|
@ -1,189 +0,0 @@
|
|||
# Whisper Language Support Guide
|
||||
|
||||
Complete guide to Whisper's multilingual capabilities.
|
||||
|
||||
## Supported languages (99 total)
|
||||
|
||||
### Top-tier support (WER < 10%)
|
||||
|
||||
- English (en)
|
||||
- Spanish (es)
|
||||
- French (fr)
|
||||
- German (de)
|
||||
- Italian (it)
|
||||
- Portuguese (pt)
|
||||
- Dutch (nl)
|
||||
- Polish (pl)
|
||||
- Russian (ru)
|
||||
- Japanese (ja)
|
||||
- Korean (ko)
|
||||
- Chinese (zh)
|
||||
|
||||
### Good support (WER 10-20%)
|
||||
|
||||
- Arabic (ar)
|
||||
- Turkish (tr)
|
||||
- Vietnamese (vi)
|
||||
- Swedish (sv)
|
||||
- Finnish (fi)
|
||||
- Czech (cs)
|
||||
- Romanian (ro)
|
||||
- Hungarian (hu)
|
||||
- Danish (da)
|
||||
- Norwegian (no)
|
||||
- Thai (th)
|
||||
- Hebrew (he)
|
||||
- Greek (el)
|
||||
- Indonesian (id)
|
||||
- Malay (ms)
|
||||
|
||||
### Full list (99 languages)
|
||||
|
||||
Afrikaans, Albanian, Amharic, Arabic, Armenian, Assamese, Azerbaijani, Bashkir, Basque, Belarusian, Bengali, Bosnian, Breton, Bulgarian, Burmese, Cantonese, Catalan, Chinese, Croatian, Czech, Danish, Dutch, English, Estonian, Faroese, Finnish, French, Galician, Georgian, German, Greek, Gujarati, Haitian Creole, Hausa, Hawaiian, Hebrew, Hindi, Hungarian, Icelandic, Indonesian, Italian, Japanese, Javanese, Kannada, Kazakh, Khmer, Korean, Lao, Latin, Latvian, Lingala, Lithuanian, Luxembourgish, Macedonian, Malagasy, Malay, Malayalam, Maltese, Maori, Marathi, Moldavian, Mongolian, Myanmar, Nepali, Norwegian, Nynorsk, Occitan, Pashto, Persian, Polish, Portuguese, Punjabi, Pushto, Romanian, Russian, Sanskrit, Serbian, Shona, Sindhi, Sinhala, Slovak, Slovenian, Somali, Spanish, Sundanese, Swahili, Swedish, Tagalog, Tajik, Tamil, Tatar, Telugu, Thai, Tibetan, Turkish, Turkmen, Ukrainian, Urdu, Uzbek, Vietnamese, Welsh, Yiddish, Yoruba
|
||||
|
||||
## Usage examples
|
||||
|
||||
### Auto-detect language
|
||||
|
||||
```python
|
||||
import whisper
|
||||
|
||||
model = whisper.load_model("turbo")
|
||||
|
||||
# Auto-detect language
|
||||
result = model.transcribe("audio.mp3")
|
||||
|
||||
print(f"Detected language: {result['language']}")
|
||||
print(f"Text: {result['text']}")
|
||||
```
|
||||
|
||||
### Specify language (faster)
|
||||
|
||||
```python
|
||||
# Specify language for faster transcription
|
||||
result = model.transcribe("audio.mp3", language="es") # Spanish
|
||||
result = model.transcribe("audio.mp3", language="fr") # French
|
||||
result = model.transcribe("audio.mp3", language="ja") # Japanese
|
||||
```
|
||||
|
||||
### Translation to English
|
||||
|
||||
```python
|
||||
# Translate any language to English
|
||||
result = model.transcribe(
|
||||
"spanish_audio.mp3",
|
||||
task="translate" # Translates to English
|
||||
)
|
||||
|
||||
print(f"Original language: {result['language']}")
|
||||
print(f"English translation: {result['text']}")
|
||||
```
|
||||
|
||||
## Language-specific tips
|
||||
|
||||
### Chinese
|
||||
|
||||
```python
|
||||
# Chinese works well with larger models
|
||||
model = whisper.load_model("large")
|
||||
|
||||
result = model.transcribe(
|
||||
"chinese_audio.mp3",
|
||||
language="zh",
|
||||
initial_prompt="这是一段关于技术的讨论" # Context helps
|
||||
)
|
||||
```
|
||||
|
||||
### Japanese
|
||||
|
||||
```python
|
||||
# Japanese benefits from initial prompt
|
||||
result = model.transcribe(
|
||||
"japanese_audio.mp3",
|
||||
language="ja",
|
||||
initial_prompt="これは技術的な会議の録音です"
|
||||
)
|
||||
```
|
||||
|
||||
### Arabic
|
||||
|
||||
```python
|
||||
# Arabic: Use large model for best results
|
||||
model = whisper.load_model("large")
|
||||
|
||||
result = model.transcribe(
|
||||
"arabic_audio.mp3",
|
||||
language="ar"
|
||||
)
|
||||
```
|
||||
|
||||
## Model size recommendations
|
||||
|
||||
| Language Tier | Recommended Model | WER |
|
||||
|---------------|-------------------|-----|
|
||||
| Top-tier (en, es, fr, de) | base/turbo | < 10% |
|
||||
| Good (ar, tr, vi) | medium/large | 10-20% |
|
||||
| Lower-resource | large | 20-30% |
|
||||
|
||||
## Performance by language
|
||||
|
||||
### English
|
||||
|
||||
- **tiny**: WER ~15%
|
||||
- **base**: WER ~8%
|
||||
- **small**: WER ~5%
|
||||
- **medium**: WER ~4%
|
||||
- **large**: WER ~3%
|
||||
- **turbo**: WER ~3.5%
|
||||
|
||||
### Spanish
|
||||
|
||||
- **tiny**: WER ~20%
|
||||
- **base**: WER ~12%
|
||||
- **medium**: WER ~6%
|
||||
- **large**: WER ~4%
|
||||
|
||||
### Chinese
|
||||
|
||||
- **small**: WER ~15%
|
||||
- **medium**: WER ~8%
|
||||
- **large**: WER ~5%
|
||||
|
||||
## Best practices
|
||||
|
||||
1. **Use English-only models** - Better for small models (tiny/base)
|
||||
2. **Specify language** - Faster than auto-detect
|
||||
3. **Add initial prompt** - Improves accuracy for technical terms
|
||||
4. **Use larger models** - For low-resource languages
|
||||
5. **Test on sample** - Quality varies by accent/dialect
|
||||
6. **Consider audio quality** - Clear audio = better results
|
||||
7. **Check language codes** - Use ISO 639-1 codes (2 letters)
|
||||
|
||||
## Language detection
|
||||
|
||||
```python
|
||||
# Detect language only (no transcription)
|
||||
import whisper
|
||||
|
||||
model = whisper.load_model("base")
|
||||
|
||||
# Load audio
|
||||
audio = whisper.load_audio("audio.mp3")
|
||||
audio = whisper.pad_or_trim(audio)
|
||||
|
||||
# Make log-Mel spectrogram
|
||||
mel = whisper.log_mel_spectrogram(audio).to(model.device)
|
||||
|
||||
# Detect language
|
||||
_, probs = model.detect_language(mel)
|
||||
detected_language = max(probs, key=probs.get)
|
||||
|
||||
print(f"Detected language: {detected_language}")
|
||||
print(f"Confidence: {probs[detected_language]:.2%}")
|
||||
```
|
||||
|
||||
## Resources
|
||||
|
||||
- **Paper**: https://arxiv.org/abs/2212.04356
|
||||
- **GitHub**: https://github.com/openai/whisper
|
||||
- **Model Card**: https://github.com/openai/whisper/blob/main/model-card.md
|
||||
Loading…
Add table
Add a link
Reference in a new issue