The architecture has been updated

This commit is contained in:
Skyber_2 2026-03-31 23:31:36 +03:00
parent 805f7a017e
commit a01257ead9
1119 changed files with 226 additions and 352 deletions


@@ -0,0 +1,3 @@
---
description: Knowledge and tools for Machine Learning Operations - frameworks and tooling for training, fine-tuning, deploying, and optimizing ML/AI models.
---


@@ -0,0 +1,3 @@
---
description: GPU cloud providers and serverless compute platforms for ML workloads.
---


@@ -0,0 +1,548 @@
---
name: lambda-labs-gpu-cloud
description: Reserved and on-demand GPU cloud instances for ML training and inference. Use when you need dedicated GPU instances with simple SSH access, persistent filesystems, or high-performance multi-node clusters for large-scale training.
version: 1.0.0
author: Orchestra Research
license: MIT
dependencies: [lambda-cloud-client>=1.0.0]
metadata:
  hermes:
    tags: [Infrastructure, GPU Cloud, Training, Inference, Lambda Labs]
---
# Lambda Labs GPU Cloud
Comprehensive guide to running ML workloads on Lambda Labs GPU cloud with on-demand instances and 1-Click Clusters.
## When to use Lambda Labs
**Use Lambda Labs when:**
- Need dedicated GPU instances with full SSH access
- Running long training jobs (hours to days)
- Want simple pricing with no egress fees
- Need persistent storage across sessions
- Require high-performance multi-node clusters (16-512 GPUs)
- Want pre-installed ML stack (Lambda Stack with PyTorch, CUDA, NCCL)
**Key features:**
- **GPU variety**: B200, H100, GH200, A100, A10, A6000, V100
- **Lambda Stack**: Pre-installed PyTorch, TensorFlow, CUDA, cuDNN, NCCL
- **Persistent filesystems**: Keep data across instance restarts
- **1-Click Clusters**: 16-512 GPU Slurm clusters with InfiniBand
- **Simple pricing**: Pay-per-minute, no egress fees
- **Global regions**: 12+ regions worldwide
**Use alternatives instead:**
- **Modal**: For serverless, auto-scaling workloads
- **SkyPilot**: For multi-cloud orchestration and cost optimization
- **RunPod**: For cheaper spot instances and serverless endpoints
- **Vast.ai**: For GPU marketplace with lowest prices
## Quick start
### Account setup
1. Create account at https://lambda.ai
2. Add payment method
3. Generate API key from dashboard
4. Add SSH key (required before launching instances)
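To confirm the API key and the SSH key registration before launching anything, a quick check with the Python client covered below can help; this is a minimal sketch that assumes the key is exported as `LAMBDA_API_KEY`.
```python
import os
import lambda_cloud_client
configuration = lambda_cloud_client.Configuration(
    host="https://cloud.lambdalabs.com/api/v1",
    access_token=os.environ["LAMBDA_API_KEY"]
)
with lambda_cloud_client.ApiClient(configuration) as api_client:
    api = lambda_cloud_client.DefaultApi(api_client)
    # A successful call confirms the API key; the SSH key from step 4 should appear here
    print(api.list_ssh_keys())
```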
### Launch via console
1. Go to https://cloud.lambda.ai/instances
2. Click "Launch instance"
3. Select GPU type and region
4. Choose SSH key
5. Optionally attach filesystem
6. Launch and wait 3-15 minutes
### Connect via SSH
```bash
# Get instance IP from console
ssh ubuntu@<INSTANCE-IP>
# Or with specific key
ssh -i ~/.ssh/lambda_key ubuntu@<INSTANCE-IP>
```
## GPU instances
### Available GPUs
| GPU | VRAM | Price/GPU/hr | Best For |
|-----|------|--------------|----------|
| B200 SXM6 | 180 GB | $4.99 | Largest models, fastest training |
| H100 SXM | 80 GB | $2.99-3.29 | Large model training |
| H100 PCIe | 80 GB | $2.49 | Cost-effective H100 |
| GH200 | 96 GB | $1.49 | Single-GPU large models |
| A100 80GB | 80 GB | $1.79 | Production training |
| A100 40GB | 40 GB | $1.29 | Standard training |
| A10 | 24 GB | $0.75 | Inference, fine-tuning |
| A6000 | 48 GB | $0.80 | Good VRAM/price ratio |
| V100 | 16 GB | $0.55 | Budget training |
### Instance configurations
```
8x GPU: Best for distributed training (DDP, FSDP)
4x GPU: Large models, multi-GPU training
2x GPU: Medium workloads
1x GPU: Fine-tuning, inference, development
```
### Launch times
- Single-GPU: 3-5 minutes
- Multi-GPU: 10-15 minutes
## Lambda Stack
All instances come with Lambda Stack pre-installed:
```bash
# Included software
- Ubuntu 22.04 LTS
- NVIDIA drivers (latest)
- CUDA 12.x
- cuDNN 8.x
- NCCL (for multi-GPU)
- PyTorch (latest)
- TensorFlow (latest)
- JAX
- JupyterLab
```
### Verify installation
```bash
# Check GPU
nvidia-smi
# Check PyTorch
python -c "import torch; print(torch.cuda.is_available())"
# Check CUDA version
nvcc --version
```
## Python API
### Installation
```bash
pip install lambda-cloud-client
```
### Authentication
```python
import os
import lambda_cloud_client
# Configure with API key
configuration = lambda_cloud_client.Configuration(
host="https://cloud.lambdalabs.com/api/v1",
access_token=os.environ["LAMBDA_API_KEY"]
)
```
### List available instances
```python
with lambda_cloud_client.ApiClient(configuration) as api_client:
    api = lambda_cloud_client.DefaultApi(api_client)
    # Get available instance types
    types = api.instance_types()
    for name, info in types.data.items():
        print(f"{name}: {info.instance_type.description}")
```
### Launch instance
```python
from lambda_cloud_client.models import LaunchInstanceRequest
request = LaunchInstanceRequest(
region_name="us-west-1",
instance_type_name="gpu_1x_h100_sxm5",
ssh_key_names=["my-ssh-key"],
file_system_names=["my-filesystem"], # Optional
name="training-job"
)
response = api.launch_instance(request)
instance_id = response.data.instance_ids[0]
print(f"Launched: {instance_id}")
```
### List running instances
```python
instances = api.list_instances()
for instance in instances.data:
print(f"{instance.name}: {instance.ip} ({instance.status})")
```
### Terminate instance
```python
from lambda_cloud_client.models import TerminateInstanceRequest
request = TerminateInstanceRequest(
instance_ids=[instance_id]
)
api.terminate_instance(request)
```
### SSH key management
```python
from lambda_cloud_client.models import AddSshKeyRequest
# Add SSH key
request = AddSshKeyRequest(
name="my-key",
public_key="ssh-rsa AAAA..."
)
api.add_ssh_key(request)
# List keys
keys = api.list_ssh_keys()
# Delete key
api.delete_ssh_key(key_id)
```
## CLI with curl
### List instance types
```bash
curl -u $LAMBDA_API_KEY: \
https://cloud.lambdalabs.com/api/v1/instance-types | jq
```
### Launch instance
```bash
curl -u $LAMBDA_API_KEY: \
-X POST https://cloud.lambdalabs.com/api/v1/instance-operations/launch \
-H "Content-Type: application/json" \
-d '{
"region_name": "us-west-1",
"instance_type_name": "gpu_1x_h100_sxm5",
"ssh_key_names": ["my-key"]
}' | jq
```
### Terminate instance
```bash
curl -u $LAMBDA_API_KEY: \
-X POST https://cloud.lambdalabs.com/api/v1/instance-operations/terminate \
-H "Content-Type: application/json" \
-d '{"instance_ids": ["<INSTANCE-ID>"]}' | jq
```
## Persistent storage
### Filesystems
Filesystems persist data across instance restarts:
```bash
# Mount location
/lambda/nfs/<FILESYSTEM_NAME>
# Example: save checkpoints
python train.py --checkpoint-dir /lambda/nfs/my-storage/checkpoints
```
### Create filesystem
1. Go to Storage in Lambda console
2. Click "Create filesystem"
3. Select region (must match instance region)
4. Name and create
### Attach to instance
Filesystems must be attached at instance launch time:
- Via console: Select filesystem when launching
- Via API: Include `file_system_names` in launch request
### Best practices
```bash
# Store on filesystem (persists)
/lambda/nfs/storage/
├── datasets/
├── checkpoints/
├── models/
└── outputs/
# Local SSD (faster, ephemeral)
/home/ubuntu/
└── working/ # Temporary files
```
## SSH configuration
### Add SSH key
```bash
# Generate key locally
ssh-keygen -t ed25519 -f ~/.ssh/lambda_key
# Add public key to Lambda console
# Or via API
```
### Multiple keys
```bash
# On instance, add more keys
echo 'ssh-rsa AAAA...' >> ~/.ssh/authorized_keys
```
### Import from GitHub
```bash
# On instance
ssh-import-id gh:username
```
### SSH tunneling
```bash
# Forward Jupyter
ssh -L 8888:localhost:8888 ubuntu@<IP>
# Forward TensorBoard
ssh -L 6006:localhost:6006 ubuntu@<IP>
# Multiple ports
ssh -L 8888:localhost:8888 -L 6006:localhost:6006 ubuntu@<IP>
```
## JupyterLab
### Launch from console
1. Go to Instances page
2. Click "Launch" in Cloud IDE column
3. JupyterLab opens in browser
### Manual access
```bash
# On instance
jupyter lab --ip=0.0.0.0 --port=8888
# From local machine with tunnel
ssh -L 8888:localhost:8888 ubuntu@<IP>
# Open http://localhost:8888
```
## Training workflows
### Single-GPU training
```bash
# SSH to instance
ssh ubuntu@<IP>
# Clone repo
git clone https://github.com/user/project
cd project
# Install dependencies
pip install -r requirements.txt
# Train
python train.py --epochs 100 --checkpoint-dir /lambda/nfs/storage/checkpoints
```
### Multi-GPU training (single node)
```python
# train_ddp.py
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
def main():
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    device = rank % torch.cuda.device_count()
    model = MyModel().to(device)
    model = DDP(model, device_ids=[device])
    # Training loop...
if __name__ == "__main__":
    main()
```
```bash
# Launch with torchrun (8 GPUs)
torchrun --nproc_per_node=8 train_ddp.py
```
### Checkpoint to filesystem
```python
import os
checkpoint_dir = "/lambda/nfs/my-storage/checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)
# Save checkpoint
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss,
}, f"{checkpoint_dir}/checkpoint_{epoch}.pt")
```
## 1-Click Clusters
### Overview
High-performance Slurm clusters with:
- 16-512 NVIDIA H100 or B200 GPUs
- NVIDIA Quantum-2 400 Gb/s InfiniBand
- GPUDirect RDMA at 3200 Gb/s
- Pre-installed distributed ML stack
### Included software
- Ubuntu 22.04 LTS + Lambda Stack
- NCCL, Open MPI
- PyTorch with DDP and FSDP
- TensorFlow
- OFED drivers
### Storage
- 24 TB NVMe per compute node (ephemeral)
- Lambda filesystems for persistent data
### Multi-node training
```bash
# On Slurm cluster
srun --nodes=4 --ntasks-per-node=8 --gpus-per-node=8 \
torchrun --nnodes=4 --nproc_per_node=8 \
--rdzv_backend=c10d --rdzv_endpoint=$MASTER_ADDR:29500 \
train.py
```
## Networking
### Bandwidth
- Inter-instance (same region): up to 200 Gbps
- Internet outbound: 20 Gbps max
### Firewall
- Default: Only port 22 (SSH) open
- Configure additional ports in Lambda console
- ICMP traffic allowed by default
### Private IPs
```bash
# Find private IP
ip addr show | grep 'inet '
```
## Common workflows
### Workflow 1: Fine-tuning LLM
```bash
# 1. Launch 8x H100 instance with filesystem
# 2. SSH and setup
ssh ubuntu@<IP>
pip install transformers accelerate peft
# 3. Download model to filesystem
python -c "
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained('meta-llama/Llama-2-7b-hf')
model.save_pretrained('/lambda/nfs/storage/models/llama-2-7b')
"
# 4. Fine-tune with checkpoints on filesystem
accelerate launch --num_processes 8 train.py \
--model_path /lambda/nfs/storage/models/llama-2-7b \
--output_dir /lambda/nfs/storage/outputs \
--checkpoint_dir /lambda/nfs/storage/checkpoints
```
### Workflow 2: Batch inference
```bash
# 1. Launch A10 instance (cost-effective for inference)
# 2. Run inference
python inference.py \
--model /lambda/nfs/storage/models/fine-tuned \
--input /lambda/nfs/storage/data/inputs.jsonl \
--output /lambda/nfs/storage/data/outputs.jsonl
```
## Cost optimization
### Choose right GPU
| Task | Recommended GPU |
|------|-----------------|
| LLM fine-tuning (7B) | A100 40GB |
| LLM fine-tuning (70B) | 8x H100 |
| Inference | A10, A6000 |
| Development | V100, A10 |
| Maximum performance | B200 |
### Reduce costs
1. **Use filesystems**: Avoid re-downloading data
2. **Checkpoint frequently**: Resume interrupted training
3. **Right-size**: Don't over-provision GPUs
4. **Terminate idle**: No auto-stop, manually terminate
### Monitor usage
- Dashboard shows real-time GPU utilization
- API for programmatic monitoring
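A minimal monitoring sketch using the Python client from the Authentication section (reusing its `configuration` object); treating `launched_at` as a timezone-aware datetime is an assumption:
```python
from datetime import datetime, timezone
with lambda_cloud_client.ApiClient(configuration) as api_client:
    api = lambda_cloud_client.DefaultApi(api_client)
    for instance in api.list_instances().data:
        # Assumes launched_at is a timezone-aware datetime
        uptime = datetime.now(timezone.utc) - instance.launched_at
        print(f"{instance.id} {instance.name}: {instance.status}, up for {uptime}")
```
Anything idle in this list is still billed by the minute, so terminate it.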
## Common issues
| Issue | Solution |
|-------|----------|
| Instance won't launch | Check region availability, try different GPU |
| SSH connection refused | Wait for instance to initialize (3-15 min) |
| Data lost after terminate | Use persistent filesystems |
| Slow data transfer | Use filesystem in same region |
| GPU not detected | Reboot instance, check drivers |
## References
- **[Advanced Usage](references/advanced-usage.md)** - Multi-node training, API automation
- **[Troubleshooting](references/troubleshooting.md)** - Common issues and solutions
## Resources
- **Documentation**: https://docs.lambda.ai
- **Console**: https://cloud.lambda.ai
- **Pricing**: https://lambda.ai/instances
- **Support**: https://support.lambdalabs.com
- **Blog**: https://lambda.ai/blog


@@ -0,0 +1,611 @@
# Lambda Labs Advanced Usage Guide
## Multi-Node Distributed Training
### PyTorch DDP across nodes
```python
# train_multi_node.py
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
def setup_distributed():
    # Environment variables set by launcher
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    local_rank = int(os.environ["LOCAL_RANK"])
    dist.init_process_group(
        backend="nccl",
        rank=rank,
        world_size=world_size
    )
    torch.cuda.set_device(local_rank)
    return rank, world_size, local_rank
def main():
    rank, world_size, local_rank = setup_distributed()
    model = MyModel().cuda(local_rank)
    model = DDP(model, device_ids=[local_rank])
    # Training loop with synchronized gradients
    for epoch in range(num_epochs):
        train_one_epoch(model, dataloader)
        # Save checkpoint on rank 0 only
        if rank == 0:
            torch.save(model.module.state_dict(), f"checkpoint_{epoch}.pt")
    dist.destroy_process_group()
if __name__ == "__main__":
    main()
```
### Launch on multiple instances
```bash
# On Node 0 (master)
export MASTER_ADDR=<NODE0_PRIVATE_IP>
export MASTER_PORT=29500
torchrun \
--nnodes=2 \
--nproc_per_node=8 \
--node_rank=0 \
--master_addr=$MASTER_ADDR \
--master_port=$MASTER_PORT \
train_multi_node.py
# On Node 1
export MASTER_ADDR=<NODE0_PRIVATE_IP>
export MASTER_PORT=29500
torchrun \
--nnodes=2 \
--nproc_per_node=8 \
--node_rank=1 \
--master_addr=$MASTER_ADDR \
--master_port=$MASTER_PORT \
train_multi_node.py
```
### FSDP for large models
```python
import functools
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
# Wrap policy for transformer models
auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={LlamaDecoderLayer}
)
model = FSDP(
    model,
    auto_wrap_policy=auto_wrap_policy,
    mixed_precision=MixedPrecision(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.bfloat16,
        buffer_dtype=torch.bfloat16,
    ),
    device_id=local_rank,
)
```
### DeepSpeed ZeRO
```python
# ds_config.json
{
"train_batch_size": 64,
"gradient_accumulation_steps": 4,
"fp16": {"enabled": true},
"zero_optimization": {
"stage": 3,
"offload_optimizer": {"device": "cpu"},
"offload_param": {"device": "cpu"}
}
}
```
```bash
# Launch with DeepSpeed
deepspeed --num_nodes=2 \
--num_gpus=8 \
--hostfile=hostfile.txt \
train.py --deepspeed ds_config.json
```
### Hostfile for multi-node
```bash
# hostfile.txt
node0_ip slots=8
node1_ip slots=8
```
## API Automation
### Auto-launch training jobs
```python
import os
import time
import lambda_cloud_client
from lambda_cloud_client.models import LaunchInstanceRequest
class LambdaJobManager:
    def __init__(self, api_key: str):
        self.config = lambda_cloud_client.Configuration(
            host="https://cloud.lambdalabs.com/api/v1",
            access_token=api_key
        )
    def find_available_gpu(self, gpu_types: list[str], regions: list[str] = None):
        """Find first available GPU type across regions."""
        with lambda_cloud_client.ApiClient(self.config) as client:
            api = lambda_cloud_client.DefaultApi(client)
            types = api.instance_types()
            for gpu_type in gpu_types:
                if gpu_type in types.data:
                    info = types.data[gpu_type]
                    for region in info.regions_with_capacity_available:
                        if regions is None or region.name in regions:
                            return gpu_type, region.name
        return None, None
    def launch_and_wait(self, instance_type: str, region: str,
                        ssh_key: str, filesystem: str = None,
                        timeout: int = 900) -> dict:
        """Launch instance and wait for it to be ready."""
        with lambda_cloud_client.ApiClient(self.config) as client:
            api = lambda_cloud_client.DefaultApi(client)
            request = LaunchInstanceRequest(
                region_name=region,
                instance_type_name=instance_type,
                ssh_key_names=[ssh_key],
                file_system_names=[filesystem] if filesystem else [],
            )
            response = api.launch_instance(request)
            instance_id = response.data.instance_ids[0]
            # Poll until ready
            start = time.time()
            while time.time() - start < timeout:
                instance = api.get_instance(instance_id)
                if instance.data.status == "active":
                    return {
                        "id": instance_id,
                        "ip": instance.data.ip,
                        "status": "active"
                    }
                time.sleep(30)
        raise TimeoutError(f"Instance {instance_id} not ready after {timeout}s")
    def terminate(self, instance_ids: list[str]):
        """Terminate instances."""
        from lambda_cloud_client.models import TerminateInstanceRequest
        with lambda_cloud_client.ApiClient(self.config) as client:
            api = lambda_cloud_client.DefaultApi(client)
            request = TerminateInstanceRequest(instance_ids=instance_ids)
            api.terminate_instance(request)
# Usage
manager = LambdaJobManager(os.environ["LAMBDA_API_KEY"])
# Find available H100 or A100
gpu_type, region = manager.find_available_gpu(
    ["gpu_8x_h100_sxm5", "gpu_8x_a100_80gb_sxm4"],
    regions=["us-west-1", "us-east-1"]
)
if gpu_type:
    instance = manager.launch_and_wait(
        gpu_type, region,
        ssh_key="my-key",
        filesystem="training-data"
    )
    print(f"Ready: ssh ubuntu@{instance['ip']}")
```
### Batch job submission
```python
import subprocess
import paramiko
def run_remote_job(ip: str, ssh_key_path: str, commands: list[str]):
    """Execute commands on remote instance."""
    client = paramiko.SSHClient()
    client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
    client.connect(ip, username="ubuntu", key_filename=ssh_key_path)
    for cmd in commands:
        stdin, stdout, stderr = client.exec_command(cmd)
        print(stdout.read().decode())
        err = stderr.read().decode()  # read stderr once; a second read would return nothing
        if err:
            print(f"Error: {err}")
    client.close()
# Submit training job
commands = [
"cd /lambda/nfs/storage/project",
"git pull",
"pip install -r requirements.txt",
"nohup torchrun --nproc_per_node=8 train.py > train.log 2>&1 &"
]
run_remote_job(instance["ip"], "~/.ssh/lambda_key", commands)
```
### Monitor training progress
```python
def monitor_job(ip: str, ssh_key_path: str, log_file: str = "train.log"):
"""Stream training logs from remote instance."""
import time
client = paramiko.SSHClient()
client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
client.connect(ip, username="ubuntu", key_filename=ssh_key_path)
# Tail log file
stdin, stdout, stderr = client.exec_command(f"tail -f {log_file}")
try:
for line in stdout:
print(line.strip())
except KeyboardInterrupt:
pass
finally:
client.close()
```
## 1-Click Cluster Workflows
### Slurm job submission
```bash
#!/bin/bash
#SBATCH --job-name=llm-training
#SBATCH --nodes=4
#SBATCH --ntasks-per-node=8
#SBATCH --gpus-per-node=8
#SBATCH --time=24:00:00
#SBATCH --output=logs/%j.out
#SBATCH --error=logs/%j.err
# Set up distributed environment
export MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1)
export MASTER_PORT=29500
# Launch training
srun torchrun \
--nnodes=$SLURM_NNODES \
--nproc_per_node=$SLURM_GPUS_PER_NODE \
--rdzv_backend=c10d \
--rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT \
train.py \
--config config.yaml
```
### Interactive cluster session
```bash
# Request interactive session
srun --nodes=1 --ntasks=1 --gpus=8 --time=4:00:00 --pty bash
# Now on compute node with 8 GPUs
nvidia-smi
python train.py
```
### Monitoring cluster jobs
```bash
# View job queue
squeue
# View job details
scontrol show job <JOB_ID>
# Cancel job
scancel <JOB_ID>
# View node status
sinfo
# View GPU usage across cluster
srun --nodes=4 nvidia-smi --query-gpu=name,utilization.gpu --format=csv
```
## Advanced Filesystem Usage
### Data staging workflow
```bash
# Stage data from S3 to filesystem (one-time)
aws s3 sync s3://my-bucket/dataset /lambda/nfs/storage/datasets/
# Or use rclone
rclone sync s3:my-bucket/dataset /lambda/nfs/storage/datasets/
```
### Shared filesystem across instances
```python
# Instance 1: Write checkpoints
checkpoint_path = "/lambda/nfs/shared/checkpoints/model_step_1000.pt"
torch.save(model.state_dict(), checkpoint_path)
# Instance 2: Read checkpoints
model.load_state_dict(torch.load(checkpoint_path))
```
### Filesystem best practices
```bash
# Organize for ML workflows
/lambda/nfs/storage/
├── datasets/
│ ├── raw/ # Original data
│ └── processed/ # Preprocessed data
├── models/
│ ├── pretrained/ # Base models
│ └── fine-tuned/ # Your trained models
├── checkpoints/
│ └── experiment_1/ # Per-experiment checkpoints
├── logs/
│ └── tensorboard/ # Training logs
└── outputs/
└── inference/ # Inference results
```
## Environment Management
### Custom Python environments
```bash
# Don't modify system Python, create venv
python -m venv ~/myenv
source ~/myenv/bin/activate
# Install packages
pip install torch transformers accelerate
# Save to filesystem for reuse
cp -r ~/myenv /lambda/nfs/storage/envs/myenv
```
### Conda environments
```bash
# Install miniconda (if not present)
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
bash Miniconda3-latest-Linux-x86_64.sh -b -p ~/miniconda3
# Create environment
~/miniconda3/bin/conda create -n ml python=3.10 pytorch pytorch-cuda=12.1 -c pytorch -c nvidia -y
# Activate
source ~/miniconda3/bin/activate ml
```
### Docker containers
```bash
# Pull and run NVIDIA container
docker run --gpus all -it --rm \
-v /lambda/nfs/storage:/data \
nvcr.io/nvidia/pytorch:24.01-py3
# Run training in container
docker run --gpus all -d \
-v /lambda/nfs/storage:/data \
-v $(pwd):/workspace \
nvcr.io/nvidia/pytorch:24.01-py3 \
python /workspace/train.py
```
## Monitoring and Observability
### GPU monitoring
```bash
# Real-time GPU stats
watch -n 1 nvidia-smi
# GPU utilization over time
nvidia-smi dmon -s u -d 1
# Detailed GPU info
nvidia-smi -q
```
### System monitoring
```bash
# CPU and memory
htop
# Disk I/O
iostat -x 1
# Network
iftop
# All resources
glances
```
### TensorBoard integration
```bash
# Start TensorBoard
tensorboard --logdir /lambda/nfs/storage/logs --port 6006 --bind_all
# SSH tunnel from local machine
ssh -L 6006:localhost:6006 ubuntu@<IP>
# Access at http://localhost:6006
```
### Weights & Biases integration
```python
import wandb
# Initialize with API key
wandb.login(key=os.environ["WANDB_API_KEY"])
# Start run
wandb.init(
project="lambda-training",
config={"learning_rate": 1e-4, "epochs": 100}
)
# Log metrics
wandb.log({"loss": loss, "accuracy": acc})
# Save artifacts to filesystem + W&B
wandb.save("/lambda/nfs/storage/checkpoints/best_model.pt")
```
## Cost Optimization Strategies
### Checkpointing for interruption recovery
```python
import os
import torch
def save_checkpoint(model, optimizer, epoch, loss, path):
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
    }, path)
def load_checkpoint(path, model, optimizer):
    if os.path.exists(path):
        checkpoint = torch.load(path)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        return checkpoint['epoch'], checkpoint['loss']
    return 0, float('inf')
# Save every N steps to filesystem
checkpoint_path = "/lambda/nfs/storage/checkpoints/latest.pt"
if step % 1000 == 0:
    save_checkpoint(model, optimizer, epoch, loss, checkpoint_path)
```
### Instance selection by workload
```python
def recommend_instance(model_params: int, batch_size: int, task: str) -> str:
    """Recommend Lambda instance based on workload."""
    if task == "inference":
        if model_params < 7e9:
            return "gpu_1x_a10"  # $0.75/hr
        elif model_params < 13e9:
            return "gpu_1x_a6000"  # $0.80/hr
        else:
            return "gpu_1x_h100_pcie"  # $2.49/hr
    elif task == "fine-tuning":
        if model_params < 7e9:
            return "gpu_1x_a100"  # $1.29/hr
        elif model_params < 13e9:
            return "gpu_4x_a100"  # $5.16/hr
        else:
            return "gpu_8x_h100_sxm5"  # $23.92/hr
    elif task == "pretraining":
        return "gpu_8x_h100_sxm5"  # Maximum performance
    return "gpu_1x_a100"  # Default
```
### Auto-terminate idle instances
```python
import time
from datetime import datetime, timedelta
def auto_terminate_idle(api_key: str, idle_threshold_hours: float = 2):
    """Terminate instances idle for too long."""
    manager = LambdaJobManager(api_key)
    with lambda_cloud_client.ApiClient(manager.config) as client:
        api = lambda_cloud_client.DefaultApi(client)
        instances = api.list_instances()
        for instance in instances.data:
            # Check if instance has been running without activity
            # (You'd need to track this separately)
            launch_time = instance.launched_at
            if datetime.now() - launch_time > timedelta(hours=idle_threshold_hours):
                print(f"Terminating idle instance: {instance.id}")
                manager.terminate([instance.id])
```
## Security Best Practices
### SSH key rotation
```bash
# Generate new key pair
ssh-keygen -t ed25519 -f ~/.ssh/lambda_key_new -C "lambda-$(date +%Y%m)"
# Add new key via Lambda console or API
# Update authorized_keys on running instances
ssh ubuntu@<IP> "echo '$(cat ~/.ssh/lambda_key_new.pub)' >> ~/.ssh/authorized_keys"
# Test new key
ssh -i ~/.ssh/lambda_key_new ubuntu@<IP>
# Remove old key from Lambda console
```
### Firewall configuration
```bash
# Lambda console: Only open necessary ports
# Recommended:
# - 22 (SSH) - Always needed
# - 6006 (TensorBoard) - If using
# - 8888 (Jupyter) - If using
# - 29500 (PyTorch distributed) - For multi-node only
```
### Secrets management
```bash
# Don't hardcode API keys in code
# Use environment variables
export HF_TOKEN="hf_..."
export WANDB_API_KEY="..."
# Or use .env file (add to .gitignore)
source .env
# On instance, store in ~/.bashrc
echo 'export HF_TOKEN="..."' >> ~/.bashrc
```

View file

@@ -0,0 +1,530 @@
# Lambda Labs Troubleshooting Guide
## Instance Launch Issues
### No instances available
**Error**: "No capacity available" or instance type not listed
**Solutions**:
```bash
# Check availability via API
curl -u $LAMBDA_API_KEY: \
https://cloud.lambdalabs.com/api/v1/instance-types | jq '.data | to_entries[] | select(.value.regions_with_capacity_available | length > 0) | .key'
# Try different regions
# US regions: us-west-1, us-east-1, us-south-1
# International: eu-west-1, asia-northeast-1, etc.
# Try alternative GPU types
# H100 not available? Try A100
# A100 not available? Try A10 or A6000
```
### Instance stuck launching
**Problem**: Instance shows "booting" for over 20 minutes
**Solutions**:
```bash
# Single-GPU: Should be ready in 3-5 minutes
# Multi-GPU (8x): May take 10-15 minutes
# If stuck longer:
# 1. Terminate the instance
# 2. Try a different region
# 3. Try a different instance type
# 4. Contact Lambda support if persistent
```
### API authentication fails
**Error**: `401 Unauthorized` or `403 Forbidden`
**Solutions**:
```bash
# Verify API key format (should start with specific prefix)
echo $LAMBDA_API_KEY
# Test API key
curl -u $LAMBDA_API_KEY: \
https://cloud.lambdalabs.com/api/v1/instance-types
# Generate new API key from Lambda console if needed
# Settings > API keys > Generate
```
### Quota limits reached
**Error**: "Instance limit reached" or "Quota exceeded"
**Solutions**:
- Check current running instances in console
- Terminate unused instances
- Contact Lambda support to request quota increase
- Use 1-Click Clusters for large-scale needs
## SSH Connection Issues
### Connection refused
**Error**: `ssh: connect to host <IP> port 22: Connection refused`
**Solutions**:
```bash
# Wait for instance to fully initialize
# Single-GPU: 3-5 minutes
# Multi-GPU: 10-15 minutes
# Check instance status in console (should be "active")
# Verify correct IP address
curl -u $LAMBDA_API_KEY: \
https://cloud.lambdalabs.com/api/v1/instances | jq '.data[].ip'
```
### Permission denied
**Error**: `Permission denied (publickey)`
**Solutions**:
```bash
# Verify SSH key matches
ssh -v -i ~/.ssh/lambda_key ubuntu@<IP>
# Check key permissions
chmod 600 ~/.ssh/lambda_key
chmod 644 ~/.ssh/lambda_key.pub
# Verify key was added to Lambda console before launch
# Keys must be added BEFORE launching instance
# Check authorized_keys on instance (if you have another way in)
cat ~/.ssh/authorized_keys
```
### Host key verification failed
**Error**: `WARNING: REMOTE HOST IDENTIFICATION HAS CHANGED!`
**Solutions**:
```bash
# This happens when IP is reused by different instance
# Remove old key
ssh-keygen -R <IP>
# Then connect again
ssh ubuntu@<IP>
```
### Timeout during SSH
**Error**: `ssh: connect to host <IP> port 22: Operation timed out`
**Solutions**:
```bash
# Check if instance is in "active" state
# Verify firewall allows SSH (port 22)
# Lambda console > Firewall
# Check your local network allows outbound SSH
# Try from different network/VPN
```
## GPU Issues
### GPU not detected
**Error**: `nvidia-smi: command not found` or no GPUs shown
**Solutions**:
```bash
# Reboot instance
sudo reboot
# Reinstall NVIDIA drivers (if needed)
wget -nv -O- https://lambdalabs.com/install-lambda-stack.sh | sh -
sudo reboot
# Check driver status
nvidia-smi
lsmod | grep nvidia
```
### CUDA out of memory
**Error**: `torch.cuda.OutOfMemoryError: CUDA out of memory`
**Solutions**:
```python
# Check GPU memory
import torch
print(torch.cuda.get_device_properties(0).total_memory / 1e9, "GB")
# Clear cache
torch.cuda.empty_cache()
# Reduce batch size
batch_size = batch_size // 2
# Enable gradient checkpointing
model.gradient_checkpointing_enable()
# Use mixed precision
from torch.cuda.amp import autocast
with autocast():
    outputs = model(**inputs)
# Use larger GPU instance
# A100-40GB → A100-80GB → H100
```
### CUDA version mismatch
**Error**: `CUDA driver version is insufficient for CUDA runtime version`
**Solutions**:
```bash
# Check versions
nvidia-smi # Shows driver CUDA version
nvcc --version # Shows toolkit version
# Lambda Stack should have compatible versions
# If mismatch, reinstall Lambda Stack
wget -nv -O- https://lambdalabs.com/install-lambda-stack.sh | sh -
sudo reboot
# Or install specific PyTorch version
pip install torch==2.1.0+cu121 -f https://download.pytorch.org/whl/torch_stable.html
```
### Multi-GPU not working
**Error**: Only one GPU being used
**Solutions**:
```python
# Check all GPUs visible
import torch
print(f"GPUs available: {torch.cuda.device_count()}")
# Verify CUDA_VISIBLE_DEVICES not set restrictively
import os
print(os.environ.get("CUDA_VISIBLE_DEVICES", "not set"))
# Use DataParallel or DistributedDataParallel
model = torch.nn.DataParallel(model)
# or
model = torch.nn.parallel.DistributedDataParallel(model)
```
## Filesystem Issues
### Filesystem not mounted
**Error**: `/lambda/nfs/<name>` doesn't exist
**Solutions**:
```bash
# Filesystem must be attached at launch time
# Cannot attach to running instance
# Verify filesystem was selected during launch
# Check mount points
df -h | grep lambda
# If missing, terminate and relaunch with filesystem
```
### Slow filesystem performance
**Problem**: Reading/writing to filesystem is slow
**Solutions**:
```bash
# Use local SSD for temporary/intermediate files
# /home/ubuntu has fast NVMe storage
# Copy frequently accessed data to local storage
cp -r /lambda/nfs/storage/dataset /home/ubuntu/dataset
# Use filesystem for checkpoints and final outputs only
# Check network bandwidth
iperf3 -c <filesystem_server>
```
### Data lost after termination
**Problem**: Files disappeared after instance terminated
**Solutions**:
```bash
# Root volume (/home/ubuntu) is EPHEMERAL
# Data there is lost on termination
# ALWAYS use filesystem for persistent data
/lambda/nfs/<filesystem_name>/
# Sync important local files before terminating
rsync -av /home/ubuntu/outputs/ /lambda/nfs/storage/outputs/
```
### Filesystem full
**Error**: `No space left on device`
**Solutions**:
```bash
# Check filesystem usage
df -h /lambda/nfs/storage
# Find large files
du -sh /lambda/nfs/storage/* | sort -h
# Clean up old checkpoints
find /lambda/nfs/storage/checkpoints -mtime +7 -delete
# Increase filesystem size in Lambda console
# (may require support request)
```
## Network Issues
### Port not accessible
**Error**: Cannot connect to service (TensorBoard, Jupyter, etc.)
**Solutions**:
```bash
# Lambda default: Only port 22 is open
# Configure firewall in Lambda console
# Or use SSH tunneling (recommended)
ssh -L 6006:localhost:6006 ubuntu@<IP>
# Access at http://localhost:6006
# For Jupyter
ssh -L 8888:localhost:8888 ubuntu@<IP>
```
### Slow data download
**Problem**: Downloading datasets is slow
**Solutions**:
```bash
# Check available bandwidth
speedtest-cli
# Use multi-threaded download
aria2c -x 16 <URL>
# For HuggingFace models
export HF_HUB_ENABLE_HF_TRANSFER=1
pip install hf_transfer
# For S3, use parallel transfer
aws s3 sync s3://bucket/data /local/data --quiet
```
### Inter-node communication fails
**Error**: Distributed training can't connect between nodes
**Solutions**:
```bash
# Verify nodes in same region (required)
# Check private IPs can communicate
ping <other_node_private_ip>
# Verify NCCL settings
export NCCL_DEBUG=INFO
export NCCL_IB_DISABLE=0 # Enable InfiniBand if available
# Check firewall allows distributed ports
# Need: 29500 (PyTorch), or configured MASTER_PORT
```
## Software Issues
### Package installation fails
**Error**: `pip install` errors
**Solutions**:
```bash
# Use virtual environment (don't modify system Python)
python -m venv ~/myenv
source ~/myenv/bin/activate
pip install <package>
# For CUDA packages, match CUDA version
pip install torch --index-url https://download.pytorch.org/whl/cu121
# Clear pip cache if corrupted
pip cache purge
```
### Python version issues
**Error**: Package requires different Python version
**Solutions**:
```bash
# Install alternate Python (don't replace system Python)
sudo apt install python3.11 python3.11-venv python3.11-dev
# Create venv with specific Python
python3.11 -m venv ~/py311env
source ~/py311env/bin/activate
```
### ImportError or ModuleNotFoundError
**Error**: Module not found despite installation
**Solutions**:
```bash
# Verify correct Python environment
which python
pip list | grep <module>
# Ensure virtual environment is activated
source ~/myenv/bin/activate
# Reinstall in correct environment
pip uninstall <package>
pip install <package>
```
## Training Issues
### Training hangs
**Problem**: Training stops progressing, no output
**Solutions**:
```bash
# Check GPU utilization
watch -n 1 nvidia-smi
# If GPUs at 0%, likely data loading bottleneck
# Increase num_workers in DataLoader
# Check for deadlocks in distributed training
export NCCL_DEBUG=INFO
# Add timeouts
dist.init_process_group(..., timeout=timedelta(minutes=30))
```
### Checkpoint corruption
**Error**: `RuntimeError: storage has wrong size` or similar
**Solutions**:
```python
# Use safe saving pattern
checkpoint_path = "/lambda/nfs/storage/checkpoint.pt"
temp_path = checkpoint_path + ".tmp"
# Save to temp first
torch.save(state_dict, temp_path)
# Then atomic rename
os.rename(temp_path, checkpoint_path)
# For loading corrupted checkpoint
try:
    state = torch.load(checkpoint_path)
except Exception:
    # Fall back to previous checkpoint
    state = torch.load(checkpoint_path + ".backup")
```
### Memory leak
**Problem**: Memory usage grows over time
**Solutions**:
```python
# Clear CUDA cache periodically
torch.cuda.empty_cache()
# Detach tensors when logging
loss_value = loss.detach().cpu().item()
# Don't accumulate gradients unintentionally
optimizer.zero_grad(set_to_none=True)
# Use gradient accumulation properly
if (step + 1) % accumulation_steps == 0:
    optimizer.step()
    optimizer.zero_grad()
```
## Billing Issues
### Unexpected charges
**Problem**: Bill higher than expected
**Solutions**:
```bash
# Check for forgotten running instances
curl -u $LAMBDA_API_KEY: \
https://cloud.lambdalabs.com/api/v1/instances | jq '.data[].id'
# Terminate all instances
# Lambda console > Instances > Terminate all
# Lambda charges by the minute
# There is no stop/pause feature - instances must be terminated to stop charges
```
### Instance terminated unexpectedly
**Problem**: Instance disappeared without manual termination
**Possible causes**:
- Payment issue (card declined)
- Account suspension
- Instance health check failure
**Solutions**:
- Check email for Lambda notifications
- Verify payment method in console
- Contact Lambda support
- Always checkpoint to filesystem
## Common Error Messages
| Error | Cause | Solution |
|-------|-------|----------|
| `No capacity available` | Region/GPU sold out | Try different region or GPU type |
| `Permission denied (publickey)` | SSH key mismatch | Re-add key, check permissions |
| `CUDA out of memory` | Model too large | Reduce batch size, use larger GPU |
| `No space left on device` | Disk full | Clean up or use filesystem |
| `Connection refused` | Instance not ready | Wait 3-15 minutes for boot |
| `Module not found` | Wrong Python env | Activate correct virtualenv |
## Getting Help
1. **Documentation**: https://docs.lambda.ai
2. **Support**: https://support.lambdalabs.com
3. **Email**: support@lambdalabs.com
4. **Status**: Check Lambda status page for outages
### Information to Include
When contacting support, include:
- Instance ID
- Region
- Instance type
- Error message (full traceback)
- Steps to reproduce
- Time of occurrence


@@ -0,0 +1,344 @@
---
name: modal-serverless-gpu
description: Serverless GPU cloud platform for running ML workloads. Use when you need on-demand GPU access without infrastructure management, deploying ML models as APIs, or running batch jobs with automatic scaling.
version: 1.0.0
author: Orchestra Research
license: MIT
dependencies: [modal>=0.64.0]
metadata:
  hermes:
    tags: [Infrastructure, Serverless, GPU, Cloud, Deployment, Modal]
---
# Modal Serverless GPU
Comprehensive guide to running ML workloads on Modal's serverless GPU cloud platform.
## When to use Modal
**Use Modal when:**
- Running GPU-intensive ML workloads without managing infrastructure
- Deploying ML models as auto-scaling APIs
- Running batch processing jobs (training, inference, data processing)
- Need pay-per-second GPU pricing without idle costs
- Prototyping ML applications quickly
- Running scheduled jobs (cron-like workloads)
**Key features:**
- **Serverless GPUs**: T4, L4, A10G, L40S, A100, H100, H200, B200 on-demand
- **Python-native**: Define infrastructure in Python code, no YAML
- **Auto-scaling**: Scale to zero, scale to 100+ GPUs instantly
- **Sub-second cold starts**: Rust-based infrastructure for fast container launches
- **Container caching**: Image layers cached for rapid iteration
- **Web endpoints**: Deploy functions as REST APIs with zero-downtime updates
**Use alternatives instead:**
- **RunPod**: For longer-running pods with persistent state
- **Lambda Labs**: For reserved GPU instances
- **SkyPilot**: For multi-cloud orchestration and cost optimization
- **Kubernetes**: For complex multi-service architectures
## Quick start
### Installation
```bash
pip install modal
modal setup # Opens browser for authentication
```
### Hello World with GPU
```python
import modal
app = modal.App("hello-gpu")
@app.function(gpu="T4")
def gpu_info():
    import subprocess
    return subprocess.run(["nvidia-smi"], capture_output=True, text=True).stdout
@app.local_entrypoint()
def main():
    print(gpu_info.remote())
```
Run: `modal run hello_gpu.py`
### Basic inference endpoint
```python
import modal
app = modal.App("text-generation")
image = modal.Image.debian_slim().pip_install("transformers", "torch", "accelerate")
@app.cls(gpu="A10G", image=image)
class TextGenerator:
    @modal.enter()
    def load_model(self):
        from transformers import pipeline
        self.pipe = pipeline("text-generation", model="gpt2", device=0)
    @modal.method()
    def generate(self, prompt: str) -> str:
        return self.pipe(prompt, max_length=100)[0]["generated_text"]
@app.local_entrypoint()
def main():
    print(TextGenerator().generate.remote("Hello, world"))
```
## Core concepts
### Key components
| Component | Purpose |
|-----------|---------|
| `App` | Container for functions and resources |
| `Function` | Serverless function with compute specs |
| `Cls` | Class-based functions with lifecycle hooks |
| `Image` | Container image definition |
| `Volume` | Persistent storage for models/data |
| `Secret` | Secure credential storage |
### Execution modes
| Command | Description |
|---------|-------------|
| `modal run script.py` | Execute and exit |
| `modal serve script.py` | Development with live reload |
| `modal deploy script.py` | Persistent cloud deployment |
## GPU configuration
### Available GPUs
| GPU | VRAM | Best For |
|-----|------|----------|
| `T4` | 16GB | Budget inference, small models |
| `L4` | 24GB | Inference, Ada Lovelace arch |
| `A10G` | 24GB | Training/inference, 3.3x faster than T4 |
| `L40S` | 48GB | Recommended for inference (best cost/perf) |
| `A100-40GB` | 40GB | Large model training |
| `A100-80GB` | 80GB | Very large models |
| `H100` | 80GB | Fastest, FP8 + Transformer Engine |
| `H200` | 141GB | Auto-upgrade from H100, 4.8TB/s bandwidth |
| `B200` | Latest | Blackwell architecture |
### GPU specification patterns
```python
# Single GPU
@app.function(gpu="A100")
# Specific memory variant
@app.function(gpu="A100-80GB")
# Multiple GPUs (up to 8)
@app.function(gpu="H100:4")
# GPU with fallbacks
@app.function(gpu=["H100", "A100", "L40S"])
# Any available GPU
@app.function(gpu="any")
```
## Container images
```python
# Basic image with pip
image = modal.Image.debian_slim(python_version="3.11").pip_install(
"torch==2.1.0", "transformers==4.36.0", "accelerate"
)
# From CUDA base
image = modal.Image.from_registry(
"nvidia/cuda:12.1.0-cudnn8-devel-ubuntu22.04",
add_python="3.11"
).pip_install("torch", "transformers")
# With system packages
image = modal.Image.debian_slim().apt_install("git", "ffmpeg").pip_install("whisper")
```
## Persistent storage
```python
volume = modal.Volume.from_name("model-cache", create_if_missing=True)
@app.function(gpu="A10G", volumes={"/models": volume})
def load_model():
    import os
    model_path = "/models/llama-7b"
    if not os.path.exists(model_path):
        model = download_model()
        model.save_pretrained(model_path)
        volume.commit()  # Persist changes
    return load_from_path(model_path)
```
## Web endpoints
### FastAPI endpoint decorator
```python
@app.function()
@modal.fastapi_endpoint(method="POST")
def predict(text: str) -> dict:
return {"result": model.predict(text)}
```
### Full ASGI app
```python
from fastapi import FastAPI
web_app = FastAPI()
@web_app.post("/predict")
async def predict(text: str):
    return {"result": await model.predict.remote.aio(text)}
@app.function()
@modal.asgi_app()
def fastapi_app():
    return web_app
```
### Web endpoint types
| Decorator | Use Case |
|-----------|----------|
| `@modal.fastapi_endpoint()` | Simple function → API |
| `@modal.asgi_app()` | Full FastAPI/Starlette apps |
| `@modal.wsgi_app()` | Django/Flask apps |
| `@modal.web_server(port)` | Arbitrary HTTP servers |
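The WSGI path has no example above; a minimal Flask sketch (Flask is an assumption here and must be added to the image) might look like:
```python
import modal
app = modal.App("flask-demo")
image = modal.Image.debian_slim().pip_install("flask")
@app.function(image=image)
@modal.wsgi_app()
def flask_app():
    from flask import Flask
    web_app = Flask(__name__)
    @web_app.route("/health")
    def health():
        return {"status": "ok"}
    return web_app
```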
## Dynamic batching
```python
@app.function()
@modal.batched(max_batch_size=32, wait_ms=100)
async def batch_predict(inputs: list[str]) -> list[dict]:
    # Inputs automatically batched
    return model.batch_predict(inputs)
```
## Secrets management
```bash
# Create secret
modal secret create huggingface HF_TOKEN=hf_xxx
```
```python
@app.function(secrets=[modal.Secret.from_name("huggingface")])
def download_model():
    import os
    token = os.environ["HF_TOKEN"]
```
## Scheduling
```python
@app.function(schedule=modal.Cron("0 0 * * *")) # Daily midnight
def daily_job():
    pass
@app.function(schedule=modal.Period(hours=1))
def hourly_job():
    pass
```
## Performance optimization
### Cold start mitigation
```python
@app.function(
    container_idle_timeout=300,  # Keep warm 5 min
    allow_concurrent_inputs=10,  # Handle concurrent requests
)
def inference():
    pass
```
### Model loading best practices
```python
@app.cls(gpu="A100")
class Model:
    @modal.enter()  # Run once at container start
    def load(self):
        self.model = load_model()  # Load during warm-up
    @modal.method()
    def predict(self, x):
        return self.model(x)
```
## Parallel processing
```python
@app.function()
def process_item(item):
    return expensive_computation(item)
@app.function()
def run_parallel():
    items = list(range(1000))
    # Fan out to parallel containers
    results = list(process_item.map(items))
    return results
```
## Common configuration
```python
@app.function(
    gpu="A100",
    memory=32768,                # 32GB RAM
    cpu=4,                       # 4 CPU cores
    timeout=3600,                # 1 hour max
    container_idle_timeout=120,  # Keep warm 2 min
    retries=3,                   # Retry on failure
    concurrency_limit=10,        # Max concurrent containers
)
def my_function():
    pass
```
## Debugging
```python
# Test locally
if __name__ == "__main__":
    result = my_function.local()
# View logs
# modal app logs my-app
```
## Common issues
| Issue | Solution |
|-------|----------|
| Cold start latency | Increase `container_idle_timeout`, use `@modal.enter()` |
| GPU OOM | Use larger GPU (`A100-80GB`), enable gradient checkpointing |
| Image build fails | Pin dependency versions, check CUDA compatibility |
| Timeout errors | Increase `timeout`, add checkpointing |
## References
- **[Advanced Usage](references/advanced-usage.md)** - Multi-GPU, distributed training, cost optimization
- **[Troubleshooting](references/troubleshooting.md)** - Common issues and solutions
## Resources
- **Documentation**: https://modal.com/docs
- **Examples**: https://github.com/modal-labs/modal-examples
- **Pricing**: https://modal.com/pricing
- **Discord**: https://discord.gg/modal


@@ -0,0 +1,503 @@
# Modal Advanced Usage Guide
## Multi-GPU Training
### Single-node multi-GPU
```python
import modal
app = modal.App("multi-gpu-training")
image = modal.Image.debian_slim().pip_install("torch", "transformers", "accelerate")
@app.function(gpu="H100:4", image=image, timeout=7200)
def train_multi_gpu():
    from accelerate import Accelerator
    accelerator = Accelerator()
    model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
    for batch in dataloader:
        outputs = model(**batch)
        loss = outputs.loss
        accelerator.backward(loss)
        optimizer.step()
```
### DeepSpeed integration
```python
image = modal.Image.debian_slim().pip_install(
"torch", "transformers", "deepspeed", "accelerate"
)
@app.function(gpu="A100:8", image=image, timeout=14400)
def deepspeed_train(config: dict):
    from transformers import Trainer, TrainingArguments
    args = TrainingArguments(
        output_dir="/outputs",
        deepspeed="ds_config.json",
        fp16=True,
        per_device_train_batch_size=4,
        gradient_accumulation_steps=4
    )
    trainer = Trainer(model=model, args=args, train_dataset=dataset)
    trainer.train()
```
### Multi-GPU considerations
For frameworks that re-execute the Python entrypoint (like PyTorch Lightning), use:
- `ddp_spawn` or `ddp_notebook` strategy
- Run training as a subprocess to avoid issues
```python
@app.function(gpu="H100:4")
def train_with_subprocess():
    import subprocess
    subprocess.run(["python", "-m", "torch.distributed.launch", "train.py"])
```
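For the strategy-based option above, a minimal PyTorch Lightning sketch (assuming `pytorch-lightning` is added to the image; `LitModel` and `data_module` are placeholder names defined elsewhere):
```python
@app.function(gpu="H100:4", image=image, timeout=7200)
def train_lightning():
    import pytorch_lightning as pl
    trainer = pl.Trainer(
        accelerator="gpu",
        devices=4,
        strategy="ddp_spawn",  # avoids re-executing the Modal entrypoint
    )
    trainer.fit(LitModel(), datamodule=data_module)  # placeholder model/data
```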
## Advanced Container Configuration
### Multi-stage builds for caching
```python
# Stage 1: Base dependencies (cached)
base_image = modal.Image.debian_slim().pip_install("torch", "numpy", "scipy")
# Stage 2: ML libraries (cached separately)
ml_image = base_image.pip_install("transformers", "datasets", "accelerate")
# Stage 3: Custom code (rebuilt on changes)
final_image = ml_image.copy_local_dir("./src", "/app/src")
```
### Custom Dockerfiles
```python
image = modal.Image.from_dockerfile("./Dockerfile")
```
### Installing from Git
```python
image = modal.Image.debian_slim().pip_install(
"git+https://github.com/huggingface/transformers.git@main"
)
```
### Using uv for faster installs
```python
image = modal.Image.debian_slim().uv_pip_install(
"torch", "transformers", "accelerate"
)
```
## Advanced Class Patterns
### Lifecycle hooks
```python
@app.cls(gpu="A10G")
class InferenceService:
    @modal.enter()
    def startup(self):
        """Called once when container starts"""
        self.model = load_model()
        self.tokenizer = load_tokenizer()
    @modal.exit()
    def shutdown(self):
        """Called when container shuts down"""
        cleanup_resources()
    @modal.method()
    def predict(self, text: str):
        return self.model(self.tokenizer(text))
```
### Concurrent request handling
```python
@app.cls(
    gpu="A100",
    allow_concurrent_inputs=20,  # Handle 20 requests per container
    container_idle_timeout=300
)
class BatchInference:
    @modal.enter()
    def load(self):
        self.model = load_model()
    @modal.method()
    def predict(self, inputs: list):
        return self.model.batch_predict(inputs)
```
### Input concurrency vs batching
- **Input concurrency**: Multiple requests processed simultaneously (async I/O)
- **Dynamic batching**: Requests accumulated and processed together (GPU efficiency)
```python
# Input concurrency - good for I/O-bound
@app.function(allow_concurrent_inputs=10)
async def fetch_data(url: str):
    async with aiohttp.ClientSession() as session:
        return await session.get(url)
# Dynamic batching - good for GPU inference
@app.function()
@modal.batched(max_batch_size=32, wait_ms=100)
async def batch_embed(texts: list[str]) -> list[list[float]]:
    return model.encode(texts)
```
## Advanced Volumes
### Volume operations
```python
volume = modal.Volume.from_name("my-volume", create_if_missing=True)
@app.function(volumes={"/data": volume})
def volume_operations():
    import os
    # Write data
    with open("/data/output.txt", "w") as f:
        f.write("Results")
    # Commit changes (persist to volume)
    volume.commit()
    # Reload from remote (get latest)
    volume.reload()
```
### Shared volumes between functions
```python
shared_volume = modal.Volume.from_name("shared-data", create_if_missing=True)
@app.function(volumes={"/shared": shared_volume})
def writer():
    with open("/shared/data.txt", "w") as f:
        f.write("Hello from writer")
    shared_volume.commit()
@app.function(volumes={"/shared": shared_volume})
def reader():
    shared_volume.reload()  # Get latest
    with open("/shared/data.txt", "r") as f:
        return f.read()
```
### Cloud bucket mounts
```python
# Mount S3 bucket
bucket = modal.CloudBucketMount(
    bucket_name="my-bucket",
    secret=modal.Secret.from_name("aws-credentials")
)
@app.function(volumes={"/s3": bucket})
def process_s3_data():
    # Access S3 files like local filesystem
    data = open("/s3/data.parquet").read()
```
## Function Composition
### Chaining functions
```python
@app.function()
def preprocess(data):
    return cleaned_data
@app.function(gpu="T4")
def inference(data):
    return predictions
@app.function()
def postprocess(predictions):
    return formatted_results
@app.function()
def pipeline(raw_data):
    cleaned = preprocess.remote(raw_data)
    predictions = inference.remote(cleaned)
    results = postprocess.remote(predictions)
    return results
```
### Parallel fan-out
```python
@app.function()
def process_item(item):
    return expensive_computation(item)
@app.function()
def parallel_pipeline(items):
    # Fan out: process all items in parallel
    results = list(process_item.map(items))
    return results
```
### Starmap for multiple arguments
```python
@app.function()
def process(x, y, z):
    return x + y + z
@app.function()
def orchestrate():
    args = [(1, 2, 3), (4, 5, 6), (7, 8, 9)]
    results = list(process.starmap(args))
    return results
```
## Advanced Web Endpoints
### WebSocket support
```python
from fastapi import FastAPI, WebSocket
app = modal.App("websocket-app")
web_app = FastAPI()
@web_app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    await websocket.accept()
    while True:
        data = await websocket.receive_text()
        await websocket.send_text(f"Processed: {data}")
@app.function()
@modal.asgi_app()
def ws_app():
    return web_app
```
### Streaming responses
```python
from fastapi.responses import StreamingResponse
@app.function(gpu="A100")
def generate_stream(prompt: str):
    for token in model.generate_stream(prompt):
        yield token
@web_app.get("/stream")
async def stream_response(prompt: str):
    return StreamingResponse(
        generate_stream.remote_gen(prompt),
        media_type="text/event-stream"
    )
```
### Authentication
```python
from fastapi import Depends, HTTPException, Header
async def verify_token(authorization: str = Header(None)):
    if not authorization or not authorization.startswith("Bearer "):
        raise HTTPException(status_code=401)
    token = authorization.split(" ")[1]
    if not verify_jwt(token):
        raise HTTPException(status_code=403)
    return token
@web_app.post("/predict")
async def predict(data: dict, token: str = Depends(verify_token)):
    return model.predict(data)
```
## Cost Optimization
### Right-sizing GPUs
```python
# For inference: smaller GPUs often sufficient
@app.function(gpu="L40S") # 48GB, best cost/perf for inference
def inference():
    pass
# For training: larger GPUs for throughput
@app.function(gpu="A100-80GB")
def training():
    pass
```
### GPU fallbacks for availability
```python
@app.function(gpu=["H100", "A100", "L40S"]) # Try in order
def flexible_compute():
    pass
```
### Scale to zero
```python
# Default behavior: scale to zero when idle
@app.function(gpu="A100")
def on_demand():
    pass
# Keep containers warm for low latency (costs more)
@app.function(gpu="A100", keep_warm=1)
def always_ready():
    pass
```
### Batch processing for efficiency
```python
# Process in batches to reduce cold starts
@app.function(gpu="A100")
def batch_process(items: list):
    return [process(item) for item in items]
# Better than individual calls
results = batch_process.remote(all_items)
```
## Monitoring and Observability
### Structured logging
```python
import json
import logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
@app.function()
def structured_logging(request_id: str, data: dict):
    logger.info(json.dumps({
        "event": "inference_start",
        "request_id": request_id,
        "input_size": len(data)
    }))
    result = process(data)
    logger.info(json.dumps({
        "event": "inference_complete",
        "request_id": request_id,
        "output_size": len(result)
    }))
    return result
```
### Custom metrics
```python
@app.function(gpu="A100")
def monitored_inference(inputs):
    import time
    start = time.time()
    results = model.predict(inputs)
    latency = time.time() - start
    # Log metrics (visible in Modal dashboard)
    print(f"METRIC latency={latency:.3f}s batch_size={len(inputs)}")
    return results
```
## Production Deployment
### Environment separation
```python
import os
env = os.environ.get("MODAL_ENV", "dev")
app = modal.App(f"my-service-{env}")
# Environment-specific config
if env == "prod":
gpu_config = "A100"
timeout = 3600
else:
gpu_config = "T4"
timeout = 300
```
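The values can then be applied when declaring functions; a small sketch using the variables from the block above:
```python
@app.function(gpu=gpu_config, timeout=timeout)
def inference():
    ...
```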
### Zero-downtime deployments
Modal automatically handles zero-downtime deployments:
1. New containers are built and started
2. Traffic gradually shifts to new version
3. Old containers drain existing requests
4. Old containers are terminated
### Health checks
```python
@app.function()
@modal.web_endpoint()
def health():
    return {
        "status": "healthy",
        "model_loaded": hasattr(Model, "_model"),
        "gpu_available": torch.cuda.is_available()
    }
```
## Sandboxes
### Interactive execution environments
```python
@app.function()
def run_sandbox():
    sandbox = modal.Sandbox.create(
        app=app,
        image=image,
        gpu="T4"
    )
    # Execute code in sandbox
    result = sandbox.exec("python", "-c", "print('Hello from sandbox')")
    sandbox.terminate()
    return result
```
## Invoking Deployed Functions
### From external code
```python
# Call deployed function from any Python script
import modal
f = modal.Function.lookup("my-app", "my_function")
result = f.remote(arg1, arg2)
```
### REST API invocation
```bash
# Deployed endpoints accessible via HTTPS
curl -X POST https://your-workspace--my-app-predict.modal.run \
-H "Content-Type: application/json" \
-d '{"text": "Hello world"}'
```


@@ -0,0 +1,494 @@
# Modal Troubleshooting Guide
## Installation Issues
### Authentication fails
**Error**: `modal setup` doesn't complete or token is invalid
**Solutions**:
```bash
# Re-authenticate
modal token new
# Check current token
modal config show
# Set token via environment
export MODAL_TOKEN_ID=ak-...
export MODAL_TOKEN_SECRET=as-...
```
### Package installation issues
**Error**: `pip install modal` fails
**Solutions**:
```bash
# Upgrade pip
pip install --upgrade pip
# Install with specific Python version
python3.11 -m pip install modal
# Install from wheel
pip install modal --prefer-binary
```
## Container Image Issues
### Image build fails
**Error**: `ImageBuilderError: Failed to build image`
**Solutions**:
```python
# Pin package versions to avoid conflicts
image = modal.Image.debian_slim().pip_install(
"torch==2.1.0",
"transformers==4.36.0", # Pin versions
"accelerate==0.25.0"
)
# Use compatible CUDA versions
image = modal.Image.from_registry(
"nvidia/cuda:12.1.0-cudnn8-runtime-ubuntu22.04", # Match PyTorch CUDA
add_python="3.11"
)
```
### Dependency conflicts
**Error**: `ERROR: Cannot install package due to conflicting dependencies`
**Solutions**:
```python
# Layer dependencies separately
base = modal.Image.debian_slim().pip_install("torch")
ml = base.pip_install("transformers") # Install after torch
# Use uv for better resolution
image = modal.Image.debian_slim().uv_pip_install(
"torch", "transformers"
)
```
### Large image builds timeout
**Error**: Image build exceeds time limit
**Solutions**:
```python
# Split into multiple layers (better caching)
base = modal.Image.debian_slim().pip_install("torch") # Cached
ml = base.pip_install("transformers", "datasets") # Cached
app = ml.copy_local_dir("./src", "/app") # Rebuilds on code change
# Download models during build, not runtime
image = modal.Image.debian_slim().pip_install("transformers").run_commands(
"python -c 'from transformers import AutoModel; AutoModel.from_pretrained(\"bert-base\")'"
)
```
## GPU Issues
### GPU not available
**Error**: `RuntimeError: CUDA not available`
**Solutions**:
```python
# Ensure GPU is specified
@app.function(gpu="T4") # Must specify GPU
def my_function():
import torch
assert torch.cuda.is_available()
# Check CUDA compatibility in image
image = modal.Image.from_registry(
"nvidia/cuda:12.1.0-cudnn8-devel-ubuntu22.04",
add_python="3.11"
).pip_install(
"torch",
index_url="https://download.pytorch.org/whl/cu121" # Match CUDA
)
```
### GPU out of memory
**Error**: `torch.cuda.OutOfMemoryError: CUDA out of memory`
**Solutions**:
```python
# Use larger GPU
@app.function(gpu="A100-80GB") # More VRAM
def train():
pass
# Enable memory optimization
@app.function(gpu="A100")
def memory_optimized():
import torch
torch.backends.cuda.enable_flash_sdp(True)
# Use gradient checkpointing
model.gradient_checkpointing_enable()
# Mixed precision
with torch.autocast(device_type="cuda", dtype=torch.float16):
outputs = model(**inputs)
```
### Wrong GPU allocated
**Error**: Got different GPU than requested
**Solutions**:
```python
# Use strict GPU selection
@app.function(gpu="H100!") # H100! prevents auto-upgrade to H200
# Specify exact memory variant
@app.function(gpu="A100-80GB") # Not just "A100"
# Check GPU at runtime
@app.function(gpu="A100")
def check_gpu():
import subprocess
result = subprocess.run(["nvidia-smi"], capture_output=True, text=True)
print(result.stdout)
```
## Cold Start Issues
### Slow cold starts
**Problem**: First request takes too long
**Solutions**:
```python
# Keep containers warm
@app.function(
container_idle_timeout=600, # Keep warm 10 min
keep_warm=1 # Always keep 1 container ready
)
def low_latency():
pass
# Load model during container start
@app.cls(gpu="A100")
class Model:
@modal.enter()
def load(self):
# This runs once at container start, not per request
self.model = load_heavy_model()
# Cache model in volume
volume = modal.Volume.from_name("models", create_if_missing=True)
@app.function(volumes={"/cache": volume})
def cached_model():
    import os
    if os.path.exists("/cache/model"):
model = load_from_disk("/cache/model")
else:
model = download_model()
save_to_disk(model, "/cache/model")
volume.commit()
```
### Container keeps restarting
**Problem**: Containers are killed and restarted frequently
**Solutions**:
```python
# Increase memory
@app.function(memory=32768) # 32GB RAM
def memory_heavy():
pass
# Increase timeout
@app.function(timeout=3600) # 1 hour
def long_running():
pass
# Handle signals gracefully
import signal
def handler(signum, frame):
cleanup()
exit(0)
signal.signal(signal.SIGTERM, handler)
```
## Volume Issues
### Volume changes not persisting
**Error**: Data written to volume disappears
**Solutions**:
```python
volume = modal.Volume.from_name("my-volume", create_if_missing=True)
@app.function(volumes={"/data": volume})
def write_data():
with open("/data/file.txt", "w") as f:
f.write("data")
# CRITICAL: Commit changes!
volume.commit()
```
### Volume read shows stale data
**Error**: Reading outdated data from volume
**Solutions**:
```python
@app.function(volumes={"/data": volume})
def read_data():
# Reload to get latest
volume.reload()
with open("/data/file.txt", "r") as f:
return f.read()
```
### Volume mount fails
**Error**: `VolumeError: Failed to mount volume`
**Solutions**:
```python
# Ensure volume exists
volume = modal.Volume.from_name("my-volume", create_if_missing=True)
# Use absolute path
@app.function(volumes={"/data": volume}) # Not "./data"
def my_function():
pass
# Check volume in dashboard
# modal volume list
```
## Web Endpoint Issues
### Endpoint returns 502
**Error**: Gateway timeout or bad gateway
**Solutions**:
```python
# Increase timeout
@app.function(timeout=300) # 5 min
@modal.web_endpoint()
def slow_endpoint():
pass
# Return a streaming response for long operations
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
@app.function()
@modal.asgi_app()
def streaming_app():
    web_app = FastAPI()
    @web_app.get("/stream")
    async def stream():
        async def generate():
            for i in range(100):
                # ... do per-chunk work here ...
                yield f"data: {i}\n\n"
        return StreamingResponse(generate(), media_type="text/event-stream")
    return web_app
```
### Endpoint not accessible
**Error**: 404 or cannot reach endpoint
**Solutions**:
```bash
# Check deployment status
modal app list
# Redeploy
modal deploy my_app.py
# Check logs
modal app logs my-app
```
### CORS errors
**Error**: Cross-origin request blocked
**Solutions**:
```python
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
web_app = FastAPI()
web_app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
@app.function()
@modal.asgi_app()
def cors_enabled():
return web_app
```
## Secret Issues
### Secret not found
**Error**: `SecretNotFound: Secret 'my-secret' not found`
**Solutions**:
```bash
# Create secret via CLI
modal secret create my-secret KEY=value
# List secrets
modal secret list
# Check secret name matches exactly
```
### Secret value not accessible
**Error**: Environment variable is empty
**Solutions**:
```python
# Ensure secret is attached
@app.function(secrets=[modal.Secret.from_name("my-secret")])
def use_secret():
import os
value = os.environ.get("KEY") # Use get() to handle missing
if not value:
raise ValueError("KEY not set in secret")
```
## Scheduling Issues
### Scheduled job not running
**Error**: Cron job doesn't execute
**Solutions**:
```python
# Verify cron syntax
@app.function(schedule=modal.Cron("0 0 * * *")) # Daily at midnight UTC
def daily_job():
pass
# Check timezone (Modal uses UTC)
# "0 8 * * *" = 8am UTC, not local time
# Ensure app is deployed
# modal deploy my_app.py
```
### Job runs multiple times
**Problem**: Scheduled job executes more than expected
**Solutions**:
```python
# Implement idempotency
@app.function(schedule=modal.Cron("0 * * * *"))
def hourly_job():
job_id = get_current_hour_id()
if already_processed(job_id):
return
process()
mark_processed(job_id)
```
## Debugging Tips
### Enable debug logging
```python
import logging
logging.basicConfig(level=logging.DEBUG)
@app.function()
def debug_function():
logging.debug("Debug message")
logging.info("Info message")
```
### View container logs
```bash
# Stream logs
modal app logs my-app
# View specific function
modal app logs my-app --function my_function
# View historical logs
modal app logs my-app --since 1h
```
### Test locally
```python
# Run function locally without Modal
if __name__ == "__main__":
result = my_function.local() # Runs on your machine
print(result)
```
### Inspect container
```python
@app.function(gpu="T4")
def debug_environment():
import subprocess
import sys
# System info
print(f"Python: {sys.version}")
print(subprocess.run(["nvidia-smi"], capture_output=True, text=True).stdout)
print(subprocess.run(["pip", "list"], capture_output=True, text=True).stdout)
# CUDA info
import torch
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"CUDA version: {torch.version.cuda}")
print(f"GPU: {torch.cuda.get_device_name(0)}")
```
## Common Error Messages
| Error | Cause | Solution |
|-------|-------|----------|
| `FunctionTimeoutError` | Function exceeded timeout | Increase `timeout` parameter |
| `ContainerMemoryExceeded` | OOM killed | Increase `memory` parameter |
| `ImageBuilderError` | Build failed | Check dependencies, pin versions |
| `ResourceExhausted` | No GPUs available | Use GPU fallbacks, try later |
| `AuthenticationError` | Invalid token | Run `modal token new` |
| `VolumeNotFound` | Volume doesn't exist | Use `create_if_missing=True` |
| `SecretNotFound` | Secret doesn't exist | Create secret via CLI |
## Getting Help
1. **Documentation**: https://modal.com/docs
2. **Examples**: https://github.com/modal-labs/modal-examples
3. **Discord**: https://discord.gg/modal
4. **Status**: https://status.modal.com
### Reporting Issues
Include:
- Modal client version: `modal --version`
- Python version: `python --version`
- Full error traceback
- Minimal reproducible code
- GPU type if relevant

View file

@ -0,0 +1,3 @@
---
description: Model evaluation benchmarks, experiment tracking, data curation, tokenizers, and interpretability tools.
---

View file

@ -0,0 +1,519 @@
---
name: huggingface-tokenizers
description: Fast tokenizers optimized for research and production. Rust-based implementation tokenizes 1GB in <20 seconds. Supports BPE, WordPiece, and Unigram algorithms. Train custom vocabularies, track alignments, handle padding/truncation. Integrates seamlessly with transformers. Use when you need high-performance tokenization or custom tokenizer training.
version: 1.0.0
author: Orchestra Research
license: MIT
dependencies: [tokenizers, transformers, datasets]
metadata:
hermes:
tags: [Tokenization, HuggingFace, BPE, WordPiece, Unigram, Fast Tokenization, Rust, Custom Tokenizer, Alignment Tracking, Production]
---
# HuggingFace Tokenizers - Fast Tokenization for NLP
Fast, production-ready tokenizers with Rust performance and Python ease-of-use.
## When to use HuggingFace Tokenizers
**Use HuggingFace Tokenizers when:**
- Need extremely fast tokenization (<20s per GB of text)
- Training custom tokenizers from scratch
- Want alignment tracking (token → original text position)
- Building production NLP pipelines
- Need to tokenize large corpora efficiently
**Performance**:
- **Speed**: <20 seconds to tokenize 1GB on CPU
- **Implementation**: Rust core with Python/Node.js bindings
- **Efficiency**: 10-100× faster than pure Python implementations
**Use alternatives instead**:
- **SentencePiece**: Language-independent, used by T5/ALBERT
- **tiktoken**: OpenAI's BPE tokenizer for GPT models
- **transformers AutoTokenizer**: Loading pretrained only (uses this library internally)
## Quick start
### Installation
```bash
# Install tokenizers
pip install tokenizers
# With transformers integration
pip install tokenizers transformers
```
### Load pretrained tokenizer
```python
from tokenizers import Tokenizer
# Load from HuggingFace Hub
tokenizer = Tokenizer.from_pretrained("bert-base-uncased")
# Encode text
output = tokenizer.encode("Hello, how are you?")
print(output.tokens) # ['hello', ',', 'how', 'are', 'you', '?']
print(output.ids) # [7592, 1010, 2129, 2024, 2017, 1029]
# Decode back
text = tokenizer.decode(output.ids)
print(text) # "hello, how are you?"
```
### Train custom BPE tokenizer
```python
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import Whitespace
# Initialize tokenizer with BPE model
tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
tokenizer.pre_tokenizer = Whitespace()
# Configure trainer
trainer = BpeTrainer(
vocab_size=30000,
special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"],
min_frequency=2
)
# Train on files
files = ["train.txt", "validation.txt"]
tokenizer.train(files, trainer)
# Save
tokenizer.save("my-tokenizer.json")
```
**Training time**: ~1-2 minutes for 100MB corpus, ~10-20 minutes for 1GB
### Batch encoding with padding
```python
# Enable padding
tokenizer.enable_padding(pad_id=3, pad_token="[PAD]")
# Encode batch
texts = ["Hello world", "This is a longer sentence"]
encodings = tokenizer.encode_batch(texts)
for encoding in encodings:
print(encoding.ids)
# [101, 7592, 2088, 102, 3, 3, 3]
# [101, 2023, 2003, 1037, 2936, 6251, 102]
```
## Tokenization algorithms
### BPE (Byte-Pair Encoding)
**How it works**:
1. Start with character-level vocabulary
2. Find most frequent character pair
3. Merge into new token, add to vocabulary
4. Repeat until vocabulary size reached
**Used by**: GPT-2, GPT-3, RoBERTa, BART, DeBERTa
```python
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import ByteLevel
tokenizer = Tokenizer(BPE(unk_token="<|endoftext|>"))
tokenizer.pre_tokenizer = ByteLevel()
trainer = BpeTrainer(
vocab_size=50257,
special_tokens=["<|endoftext|>"],
min_frequency=2
)
tokenizer.train(files=["data.txt"], trainer=trainer)
```
**Advantages**:
- Handles OOV words well (breaks into subwords)
- Flexible vocabulary size
- Good for morphologically rich languages
**Trade-offs**:
- Tokenization depends on merge order
- May split common words unexpectedly
### WordPiece
**How it works**:
1. Start with character vocabulary
2. Score merge pairs: `frequency(pair) / (frequency(first) × frequency(second))`
3. Merge highest scoring pair
4. Repeat until vocabulary size reached
**Used by**: BERT, DistilBERT, MobileBERT
```python
from tokenizers import Tokenizer
from tokenizers.models import WordPiece
from tokenizers.trainers import WordPieceTrainer
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.normalizers import BertNormalizer
tokenizer = Tokenizer(WordPiece(unk_token="[UNK]"))
tokenizer.normalizer = BertNormalizer(lowercase=True)
tokenizer.pre_tokenizer = Whitespace()
trainer = WordPieceTrainer(
vocab_size=30522,
special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"],
continuing_subword_prefix="##"
)
tokenizer.train(files=["corpus.txt"], trainer=trainer)
```
**Advantages**:
- Prioritizes meaningful merges (high score = semantically related)
- Used successfully in BERT (state-of-the-art results)
**Trade-offs**:
- Unknown words become `[UNK]` if no subword match
- Saves vocabulary, not merge rules (larger files)
### Unigram
**How it works**:
1. Start with large vocabulary (all substrings)
2. Compute loss for corpus with current vocabulary
3. Remove tokens with minimal impact on loss
4. Repeat until vocabulary size reached
**Used by**: ALBERT, T5, mBART, XLNet (via SentencePiece)
```python
from tokenizers import Tokenizer
from tokenizers.models import Unigram
from tokenizers.trainers import UnigramTrainer
tokenizer = Tokenizer(Unigram())
trainer = UnigramTrainer(
vocab_size=8000,
special_tokens=["<unk>", "<s>", "</s>"],
unk_token="<unk>"
)
tokenizer.train(files=["data.txt"], trainer=trainer)
```
**Advantages**:
- Probabilistic (finds most likely tokenization)
- Works well for languages without word boundaries
- Handles diverse linguistic contexts
**Trade-offs**:
- Computationally expensive to train
- More hyperparameters to tune
## Tokenization pipeline
Complete pipeline: **Normalization → Pre-tokenization → Model → Post-processing**
### Normalization
Clean and standardize text:
```python
from tokenizers.normalizers import NFD, StripAccents, Lowercase, Sequence
tokenizer.normalizer = Sequence([
NFD(), # Unicode normalization (decompose)
Lowercase(), # Convert to lowercase
StripAccents() # Remove accents
])
# Input: "Héllo WORLD"
# After normalization: "hello world"
```
**Common normalizers**:
- `NFD`, `NFC`, `NFKD`, `NFKC` - Unicode normalization forms
- `Lowercase()` - Convert to lowercase
- `StripAccents()` - Remove accents (é → e)
- `Strip()` - Remove whitespace
- `Replace(pattern, content)` - Regex replacement
### Pre-tokenization
Split text into word-like units:
```python
from tokenizers.pre_tokenizers import Whitespace, Punctuation, Sequence, ByteLevel
# Split on whitespace and punctuation
tokenizer.pre_tokenizer = Sequence([
Whitespace(),
Punctuation()
])
# Input: "Hello, world!"
# After pre-tokenization: ["Hello", ",", "world", "!"]
```
**Common pre-tokenizers**:
- `Whitespace()` - Split on spaces, tabs, newlines
- `ByteLevel()` - GPT-2 style byte-level splitting
- `Punctuation()` - Isolate punctuation
- `Digits(individual_digits=True)` - Split digits individually
- `Metaspace()` - Replace spaces with ▁ (SentencePiece style)
### Post-processing
Add special tokens for model input:
```python
from tokenizers.processors import TemplateProcessing
# BERT-style: [CLS] sentence [SEP]
tokenizer.post_processor = TemplateProcessing(
single="[CLS] $A [SEP]",
pair="[CLS] $A [SEP] $B [SEP]",
special_tokens=[
("[CLS]", 1),
("[SEP]", 2),
],
)
```
**Common patterns**:
```python
# GPT-2: sentence <|endoftext|>
TemplateProcessing(
single="$A <|endoftext|>",
special_tokens=[("<|endoftext|>", 50256)]
)
# RoBERTa: <s> sentence </s>
TemplateProcessing(
single="<s> $A </s>",
pair="<s> $A </s> </s> $B </s>",
special_tokens=[("<s>", 0), ("</s>", 2)]
)
```
## Alignment tracking
Track token positions in original text:
```python
output = tokenizer.encode("Hello, world!")
# Get token offsets
for token, offset in zip(output.tokens, output.offsets):
start, end = offset
print(f"{token:10} → [{start:2}, {end:2}): {text[start:end]!r}")
# Output:
# hello → [ 0, 5): 'Hello'
# , → [ 5, 6): ','
# world → [ 7, 12): 'world'
# ! → [12, 13): '!'
```
**Use cases**:
- Named entity recognition (map predictions back to text)
- Question answering (extract answer spans)
- Token classification (align labels to original positions)
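For example, a minimal sketch of the NER case, mapping a character-level entity annotation back to token indices via `offsets` (the span and label here are made-up illustration values):
```python
from tokenizers import Tokenizer
tokenizer = Tokenizer.from_pretrained("bert-base-uncased")
text = "Paris is the capital of France"
output = tokenizer.encode(text)
# Hypothetical character-level annotation: "France" labeled as a location
ent_start = text.index("France")
ent_end = ent_start + len("France")
# Keep every token whose character span overlaps the annotated span
entity_token_idxs = [
    i for i, (start, end) in enumerate(output.offsets)
    if start < ent_end and end > ent_start
]
print([output.tokens[i] for i in entity_token_idxs])  # e.g. ['france']
```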
## Integration with transformers
### Load with AutoTokenizer
```python
from transformers import AutoTokenizer
# AutoTokenizer automatically uses fast tokenizers
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
# Check if using fast tokenizer
print(tokenizer.is_fast) # True
# Access underlying tokenizers.Tokenizer
fast_tokenizer = tokenizer.backend_tokenizer
print(type(fast_tokenizer)) # <class 'tokenizers.Tokenizer'>
```
### Convert custom tokenizer to transformers
```python
from tokenizers import Tokenizer
from transformers import PreTrainedTokenizerFast
# Train custom tokenizer
tokenizer = Tokenizer(BPE())
# ... train tokenizer ...
tokenizer.save("my-tokenizer.json")
# Wrap for transformers
transformers_tokenizer = PreTrainedTokenizerFast(
tokenizer_file="my-tokenizer.json",
unk_token="[UNK]",
pad_token="[PAD]",
cls_token="[CLS]",
sep_token="[SEP]",
mask_token="[MASK]"
)
# Use like any transformers tokenizer
outputs = transformers_tokenizer(
"Hello world",
padding=True,
truncation=True,
max_length=512,
return_tensors="pt"
)
```
## Common patterns
### Train from iterator (large datasets)
```python
from datasets import load_dataset
# Load dataset
dataset = load_dataset("wikitext", "wikitext-103-raw-v1", split="train")
# Create batch iterator
def batch_iterator(batch_size=1000):
for i in range(0, len(dataset), batch_size):
yield dataset[i:i + batch_size]["text"]
# Train tokenizer
tokenizer.train_from_iterator(
batch_iterator(),
trainer=trainer,
length=len(dataset) # For progress bar
)
```
**Performance**: Processes 1GB in ~10-20 minutes
### Enable truncation and padding
```python
# Enable truncation
tokenizer.enable_truncation(max_length=512)
# Enable padding
tokenizer.enable_padding(
pad_id=tokenizer.token_to_id("[PAD]"),
pad_token="[PAD]",
length=512 # Fixed length, or None for batch max
)
# Encode with both
output = tokenizer.encode("This is a long sentence that will be truncated...")
print(len(output.ids)) # 512
```
### Multi-processing
```python
from tokenizers import Tokenizer
from multiprocessing import Pool
# Load tokenizer
tokenizer = Tokenizer.from_file("tokenizer.json")
def encode_batch(texts):
return tokenizer.encode_batch(texts)
# Process large corpus in parallel
with Pool(8) as pool:
# Split corpus into chunks
chunk_size = 1000
chunks = [corpus[i:i+chunk_size] for i in range(0, len(corpus), chunk_size)]
# Encode in parallel
results = pool.map(encode_batch, chunks)
```
**Speedup**: 5-8× with 8 cores
## Performance benchmarks
### Training speed
| Corpus Size | BPE (30k vocab) | WordPiece (30k) | Unigram (8k) |
|-------------|-----------------|-----------------|--------------|
| 10 MB | 15 sec | 18 sec | 25 sec |
| 100 MB | 1.5 min | 2 min | 4 min |
| 1 GB | 15 min | 20 min | 40 min |
**Hardware**: 16-core CPU, tested on English Wikipedia
### Tokenization speed
| Implementation | 1 GB corpus | Throughput |
|----------------|-------------|---------------|
| Pure Python | ~20 minutes | ~50 MB/min |
| HF Tokenizers | ~15 seconds | ~4 GB/min |
| **Speedup** | **80×** | **80×** |
**Test**: English text, average sentence length 20 words
### Memory usage
| Task | Memory |
|-------------------------|---------|
| Load tokenizer | ~10 MB |
| Train BPE (30k vocab) | ~200 MB |
| Encode 1M sentences | ~500 MB |
## Supported models
Pre-trained tokenizers available via `from_pretrained()`:
**BERT family**:
- `bert-base-uncased`, `bert-large-cased`
- `distilbert-base-uncased`
- `roberta-base`, `roberta-large`
**GPT family**:
- `gpt2`, `gpt2-medium`, `gpt2-large`
- `distilgpt2`
**T5 family**:
- `t5-small`, `t5-base`, `t5-large`
- `google/flan-t5-xxl`
**Other**:
- `facebook/bart-base`, `facebook/mbart-large-cc25`
- `albert-base-v2`, `albert-xlarge-v2`
- `xlm-roberta-base`, `xlm-roberta-large`
Browse all: https://huggingface.co/models?library=tokenizers
## References
- **[Training Guide](references/training.md)** - Train custom tokenizers, configure trainers, handle large datasets
- **[Algorithms Deep Dive](references/algorithms.md)** - BPE, WordPiece, Unigram explained in detail
- **[Pipeline Components](references/pipeline.md)** - Normalizers, pre-tokenizers, post-processors, decoders
- **[Transformers Integration](references/integration.md)** - AutoTokenizer, PreTrainedTokenizerFast, special tokens
## Resources
- **Docs**: https://huggingface.co/docs/tokenizers
- **GitHub**: https://github.com/huggingface/tokenizers ⭐ 9,000+
- **Version**: 0.20.0+
- **Course**: https://huggingface.co/learn/nlp-course/chapter6/1
- **Paper**: BPE (Sennrich et al., 2016), WordPiece (Schuster & Nakajima, 2012)

View file

@ -0,0 +1,653 @@
# Tokenization Algorithms Deep Dive
Comprehensive explanation of BPE, WordPiece, and Unigram algorithms.
## Byte-Pair Encoding (BPE)
### Algorithm overview
BPE iteratively merges the most frequent pair of tokens in a corpus.
**Training process**:
1. Initialize vocabulary with all characters
2. Count frequency of all adjacent token pairs
3. Merge most frequent pair into new token
4. Add new token to vocabulary
5. Update corpus with new token
6. Repeat until vocabulary size reached
### Step-by-step example
**Corpus**:
```
low: 5
lower: 2
newest: 6
widest: 3
```
**Iteration 1**:
```
Count pairs:
'e' + 's': 9 (newest: 6, widest: 3) ← most frequent
'l' + 'o': 7
'o' + 'w': 7
...
Merge: 'e' + 's' → 'es'
Updated corpus:
low: 5
lower: 2
newest: 6 → n e w es t
widest: 3 → w i d es t
Vocabulary: [a-z] + ['es']
```
**Iteration 2**:
```
Count pairs:
'es' + 't': 9 ← most frequent
'l' + 'o': 7
...
Merge: 'es' + 't' → 'est'
Updated corpus:
low: 5
lower: 2
newest: 6 → n e w est
widest: 3 → w i d est
Vocabulary: [a-z] + ['es', 'est']
```
**Continue until desired vocabulary size...**
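The merge loop itself is small. Below is a from-scratch sketch of the counting-and-merging step on the toy corpus above (illustrative only; the library's Rust implementation is organized differently):
```python
from collections import Counter
# Word frequencies from the example above, each word split into characters
corpus = {("l","o","w"): 5, ("l","o","w","e","r"): 2,
          ("n","e","w","e","s","t"): 6, ("w","i","d","e","s","t"): 3}
def bpe_merges(corpus, num_merges):
    merges = []
    for _ in range(num_merges):
        # Count adjacent symbol pairs, weighted by word frequency
        pairs = Counter()
        for word, freq in corpus.items():
            for a, b in zip(word, word[1:]):
                pairs[(a, b)] += freq
        if not pairs:
            break
        best = max(pairs, key=pairs.get)
        merges.append(best)
        # Rewrite every word, fusing occurrences of the best pair
        new_corpus = {}
        for word, freq in corpus.items():
            merged, i = [], 0
            while i < len(word):
                if i + 1 < len(word) and (word[i], word[i + 1]) == best:
                    merged.append(word[i] + word[i + 1])
                    i += 2
                else:
                    merged.append(word[i])
                    i += 1
            new_corpus[tuple(merged)] = freq
        corpus = new_corpus
    return merges
print(bpe_merges(corpus, 2))  # [('e', 's'), ('es', 't')]
```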
### Tokenization with trained BPE
Given vocabulary: `['l', 'o', 'w', 'e', 'r', 'n', 's', 't', 'i', 'd', 'es', 'est', 'lo', 'low', 'ne', 'new', 'newest', 'wi', 'wid', 'widest']`
Tokenize "lowest":
```
Step 1: Split into characters
['l', 'o', 'w', 'e', 's', 't']
Step 2: Apply merges in order learned during training
- Merge 'l' + 'o' → 'lo' (if this merge was learned)
- Merge 'lo' + 'w' → 'low' (if learned)
- Merge 'e' + 's' → 'es' (learned)
- Merge 'es' + 't' → 'est' (learned)
Final: ['low', 'est']
```
### Implementation
```python
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import Whitespace
# Initialize
tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
tokenizer.pre_tokenizer = Whitespace()
# Configure trainer
trainer = BpeTrainer(
vocab_size=1000,
min_frequency=2,
special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"]
)
# Train
corpus = [
"This is a sample corpus for BPE training.",
"BPE learns subword units from the training data.",
# ... more sentences
]
tokenizer.train_from_iterator(corpus, trainer=trainer)
# Use
output = tokenizer.encode("This is tokenization")
print(output.tokens) # ['This', 'is', 'token', 'ization']
```
### Byte-level BPE (GPT-2 variant)
**Problem**: Standard BPE has limited character coverage (256+ Unicode chars)
**Solution**: Operate on byte level (256 bytes)
```python
from tokenizers.pre_tokenizers import ByteLevel
from tokenizers.decoders import ByteLevel as ByteLevelDecoder
tokenizer = Tokenizer(BPE())
# Byte-level pre-tokenization
tokenizer.pre_tokenizer = ByteLevel()
tokenizer.decoder = ByteLevelDecoder()
# This handles ALL possible characters, including emojis
text = "Hello 🌍 世界"
tokens = tokenizer.encode(text).tokens
```
**Advantages**:
- Handles any Unicode character (256 byte coverage)
- No unknown tokens (worst case: bytes)
- Used by GPT-2, GPT-3, BART
**Trade-offs**:
- Slightly worse compression (bytes vs characters)
- More tokens for non-ASCII text
### BPE variants
**SentencePiece BPE**:
- Language-independent (no pre-tokenization)
- Treats input as raw byte stream
- Used by T5, ALBERT, XLNet
**Robust BPE**:
- Dropout during training (randomly skip merges)
- More robust tokenization at inference
- Reduces overfitting to training data
## WordPiece
### Algorithm overview
WordPiece is similar to BPE but uses a different merge selection criterion.
**Training process**:
1. Initialize vocabulary with all characters
2. Count frequency of all token pairs
3. Score each pair: `score = freq(pair) / (freq(first) × freq(second))`
4. Merge pair with highest score
5. Repeat until vocabulary size reached
### Why different scoring?
**BPE**: Merges most frequent pairs
- "aa" appears 100 times → high priority
- Even if 'a' appears 1000 times alone
**WordPiece**: Merges pairs that are semantically related
- "aa" appears 100 times, 'a' appears 1000 times → low score (100 / (1000 × 1000))
- "th" appears 50 times, 't' appears 60 times, 'h' appears 55 times → high score (50 / (60 × 55))
- Prioritizes pairs that appear together more than expected
### Step-by-step example
**Corpus**:
```
low: 5
lower: 2
newest: 6
widest: 3
```
**Iteration 1**:
```
Count frequencies:
'e': 11 (lower: 2, newest: 6, widest: 3)
's': 9
't': 9
...
Count pairs:
'e' + 's': 9 (newest: 6, widest: 3)
'es' + 't': 9 (newest: 6, widest: 3)
...
Compute scores: score = freq(pair) / (freq(first) × freq(second))
score('e' + 's') = 9 / (11 × 9) = 0.091
score('s' + 't') = 9 / (9 × 9)  = 0.111
score('l' + 'o') = 7 / (7 × 7)  = 0.143
...
Among these, 'l' + 'o' scores highest even though 'e' + 's' occurs more often:
WordPiece merges by score, not by raw pair frequency.
```
**Key difference**: WordPiece prioritizes rare combinations over frequent ones.
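The scoring step can be sketched in a few lines (illustrative only; it counts every symbol occurrence, so the raw numbers differ slightly from the simplified per-word counts used above):
```python
from collections import Counter
def wordpiece_scores(corpus):
    # corpus: dict mapping a word (tuple of symbols) to its frequency
    pair_freq, sym_freq = Counter(), Counter()
    for word, freq in corpus.items():
        for sym in word:
            sym_freq[sym] += freq
        for a, b in zip(word, word[1:]):
            pair_freq[(a, b)] += freq
    # score = freq(pair) / (freq(first) × freq(second))
    return {p: f / (sym_freq[p[0]] * sym_freq[p[1]]) for p, f in pair_freq.items()}
corpus = {("l","o","w"): 5, ("l","o","w","e","r"): 2,
          ("n","e","w","e","s","t"): 6, ("w","i","d","e","s","t"): 3}
scores = wordpiece_scores(corpus)
best = max(scores, key=scores.get)  # the pair WordPiece would merge next
```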
### Tokenization with WordPiece
Given vocabulary: `['##e', '##s', '##t', 'l', 'o', 'w', 'new', 'est', 'low']`
Tokenize "lowest":
```
Step 1: Find longest matching prefix
'lowest' → 'low' (matches)
Step 2: Find longest match for remainder
'est' → 'est' (matches)
Final: ['low', 'est']
```
**If no match**:
```
Tokenize "unknownword":
'unknownword' → no match
'unknown' → no match
'unkn' → no match
'un' → no match
'u' → no match
→ [UNK]
```
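A minimal sketch of this greedy longest-match-first lookup (the real implementation also caps word length and does more ## bookkeeping):
```python
def wordpiece_tokenize(word, vocab, unk_token="[UNK]", prefix="##"):
    tokens, start = [], 0
    while start < len(word):
        end = len(word)
        match = None
        # Greedily try the longest piece starting at `start`, shrinking from the right
        while start < end:
            piece = word[start:end]
            if start > 0:
                piece = prefix + piece  # continuation pieces carry the ## prefix
            if piece in vocab:
                match = piece
                break
            end -= 1
        if match is None:
            return [unk_token]  # no piece matched: the whole word becomes [UNK]
        tokens.append(match)
        start = end
    return tokens
vocab = {"low", "est", "##est", "new", "##e", "##s", "##t"}
print(wordpiece_tokenize("lowest", vocab))  # ['low', '##est']
print(wordpiece_tokenize("xyz", vocab))     # ['[UNK]']
```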
### Implementation
```python
from tokenizers import Tokenizer
from tokenizers.models import WordPiece
from tokenizers.trainers import WordPieceTrainer
from tokenizers.normalizers import BertNormalizer
from tokenizers.pre_tokenizers import BertPreTokenizer
# Initialize BERT-style tokenizer
tokenizer = Tokenizer(WordPiece(unk_token="[UNK]"))
# Normalization (lowercase, accent stripping)
tokenizer.normalizer = BertNormalizer(lowercase=True)
# Pre-tokenization (whitespace + punctuation)
tokenizer.pre_tokenizer = BertPreTokenizer()
# Configure trainer
trainer = WordPieceTrainer(
vocab_size=30522, # BERT vocab size
min_frequency=2,
special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"],
continuing_subword_prefix="##" # BERT uses ##
)
# Train
tokenizer.train_from_iterator(corpus, trainer=trainer)
# Use
output = tokenizer.encode("Tokenization works great!")
print(output.tokens) # ['token', '##ization', 'works', 'great', '!']
```
### Subword prefix
**BERT uses `##` prefix**:
```
"unbelievable" → ['un', '##believ', '##able']
```
**Why?**
- Indicates token is a continuation
- Allows reconstruction: remove ##, concatenate
- Helps model distinguish word boundaries
### WordPiece advantages
**Semantic merges**:
- Prioritizes meaningful combinations
- "qu" has high score (always together)
- "qx" has low score (rare combination)
**Better for morphology**:
- Captures affixes: un-, -ing, -ed
- Preserves word stems
**Trade-offs**:
- Slower training than BPE
- More memory (stores vocabulary, not merges)
- Original implementation not open-source (HF reimplementation)
## Unigram
### Algorithm overview
Unigram works backward: start with large vocabulary, remove tokens.
**Training process**:
1. Initialize with large vocabulary (all substrings)
2. Estimate probability of each token (frequency-based)
3. For each token, compute loss increase if removed
4. Remove 10-20% of tokens with lowest loss impact
5. Re-estimate probabilities
6. Repeat until desired vocabulary size
### Probabilistic tokenization
**Unigram assumption**: Each token is independent.
Given vocabulary with probabilities:
```
P('low') = 0.02
P('l') = 0.01
P('o') = 0.015
P('w') = 0.01
P('est') = 0.03
P('e') = 0.02
P('s') = 0.015
P('t') = 0.015
```
Tokenize "lowest":
```
Option 1: ['low', 'est']
P = P('low') × P('est') = 0.02 × 0.03 = 0.0006
Option 2: ['l', 'o', 'w', 'est']
P = 0.01 × 0.015 × 0.01 × 0.03 = 0.000000045
Option 3: ['low', 'e', 's', 't']
P = 0.02 × 0.02 × 0.015 × 0.015 = 0.00000009
Choose option 1 (highest probability)
```
### Viterbi algorithm
Finding best tokenization is expensive (exponential possibilities).
**Viterbi algorithm** (dynamic programming):
```python
from math import log

def tokenize_viterbi(word, vocab, probs):
    n = len(word)
    # dp[i] = (best_log_prob, best_tokens) for the prefix word[:i]
    dp = [(float('-inf'), []) for _ in range(n + 1)]
    dp[0] = (0.0, [])  # empty prefix has log probability 0
    for i in range(1, n + 1):
        best_prob = float('-inf')
        best_tokens = []
        # Try all possible last tokens ending at position i
        for j in range(i):
            token = word[j:i]
            if token in vocab and dp[j][0] > float('-inf'):
                prob = dp[j][0] + log(probs[token])
                if prob > best_prob:
                    best_prob = prob
                    best_tokens = dp[j][1] + [token]
        dp[i] = (best_prob, best_tokens)
    return dp[n][1]
```
**Time complexity**: O(n²) substring lookups with a hash-based vocabulary vs O(2^n) brute-force enumeration of all segmentations
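Running it on the toy probabilities above reproduces the hand-computed segmentation:
```python
probs = {"low": 0.02, "l": 0.01, "o": 0.015, "w": 0.01,
         "est": 0.03, "e": 0.02, "s": 0.015, "t": 0.015}
vocab = set(probs)
print(tokenize_viterbi("lowest", vocab, probs))  # ['low', 'est']
```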
### Implementation
```python
from tokenizers import Tokenizer
from tokenizers.models import Unigram
from tokenizers.trainers import UnigramTrainer
# Initialize
tokenizer = Tokenizer(Unigram())
# Configure trainer
trainer = UnigramTrainer(
vocab_size=8000,
special_tokens=["<unk>", "<s>", "</s>"],
unk_token="<unk>",
max_piece_length=16, # Max token length
n_sub_iterations=2, # EM iterations
shrinking_factor=0.75 # Remove 25% each iteration
)
# Train
tokenizer.train_from_iterator(corpus, trainer=trainer)
# Use
output = tokenizer.encode("Tokenization with Unigram")
print(output.tokens) # ['▁Token', 'ization', '▁with', '▁Un', 'igram']
```
### Unigram advantages
**Probabilistic**:
- Multiple valid tokenizations
- Can sample different tokenizations (data augmentation)
**Subword regularization**:
```python
# Sample different tokenizations
for _ in range(3):
tokens = tokenizer.encode("tokenization", is_pretokenized=False).tokens
print(tokens)
# Output (different each time):
# ['token', 'ization']
# ['tok', 'en', 'ization']
# ['token', 'iz', 'ation']
```
**Language-independent**:
- No word boundaries needed
- Works for CJK languages (Chinese, Japanese, Korean)
- Treats input as character stream
**Trade-offs**:
- Slower training (EM algorithm)
- More hyperparameters
- Larger model (stores probabilities)
## Algorithm comparison
### Training speed
| Algorithm | Small (10MB) | Medium (100MB) | Large (1GB) |
|------------|--------------|----------------|-------------|
| BPE | 10-15 sec | 1-2 min | 10-20 min |
| WordPiece | 15-20 sec | 2-3 min | 15-30 min |
| Unigram | 20-30 sec | 3-5 min | 30-60 min |
**Tested on**: 16-core CPU, 30k vocab
### Tokenization quality
Measured on English Wikipedia:
| Algorithm | Vocab Size | Tokens/Word | Unknown Rate |
|------------|------------|-------------|--------------|
| BPE | 30k | 1.3 | 0.5% |
| WordPiece | 30k | 1.2 | 1.2% |
| Unigram | 8k | 1.5 | 0.3% |
**Key observations**:
- WordPiece: Slightly better compression
- BPE: Lower unknown rate
- Unigram: Smallest vocab, good coverage
### Compression ratio
Characters per token (higher = better compression):
| Language | BPE (30k) | WordPiece (30k) | Unigram (8k) |
|----------|-----------|-----------------|--------------|
| English | 4.2 | 4.5 | 3.8 |
| Chinese | 2.1 | 2.3 | 2.5 |
| Arabic | 3.5 | 3.8 | 3.2 |
**Best for each**:
- English: WordPiece
- Chinese: Unigram (language-independent)
- Arabic: WordPiece
### Use case recommendations
**BPE** - Best for:
- English language models
- Code (handles symbols well)
- Fast training needed
- **Models**: GPT-2, GPT-3, RoBERTa, BART
**WordPiece** - Best for:
- Masked language modeling (BERT-style)
- Morphologically rich languages
- Semantic understanding tasks
- **Models**: BERT, DistilBERT, ELECTRA
**Unigram** - Best for:
- Multilingual models
- Languages without word boundaries (CJK)
- Data augmentation via subword regularization
- **Models**: T5, ALBERT, XLNet (via SentencePiece)
## Advanced topics
### Handling rare words
**BPE approach**:
```
"antidisestablishmentarianism"
→ ['anti', 'dis', 'establish', 'ment', 'arian', 'ism']
```
**WordPiece approach**:
```
"antidisestablishmentarianism"
→ ['anti', '##dis', '##establish', '##ment', '##arian', '##ism']
```
**Unigram approach**:
```
"antidisestablishmentarianism"
→ ['▁anti', 'dis', 'establish', 'ment', 'arian', 'ism']
```
### Handling numbers
**Challenge**: Infinite number combinations
**BPE solution**: Byte-level (handles any digit sequence)
```python
tokenizer = Tokenizer(BPE())
tokenizer.pre_tokenizer = ByteLevel()
# Handles any number
"123456789" → byte-level tokens
```
**WordPiece solution**: Digit pre-tokenization
```python
from tokenizers.pre_tokenizers import Digits
# Split digits individually or as groups
tokenizer.pre_tokenizer = Digits(individual_digits=True)
"123" → ['1', '2', '3']
```
**Unigram solution**: Learns common number patterns
```python
# Learns patterns during training
"2023" → ['202', '3'] or ['20', '23']
```
### Handling case sensitivity
**Lowercase (BERT)**:
```python
from tokenizers.normalizers import Lowercase
tokenizer.normalizer = Lowercase()
"Hello WORLD" → "hello world" → ['hello', 'world']
```
**Preserve case (GPT-2)**:
```python
# No case normalization
tokenizer.normalizer = None
"Hello WORLD" → ['Hello', 'WORLD']
```
**Cased tokens (RoBERTa)**:
```python
# Learns separate tokens for different cases
Vocabulary: ['Hello', 'hello', 'HELLO', 'world', 'WORLD']
```
### Handling emojis and special characters
**Byte-level (GPT-2)**:
```python
tokenizer.pre_tokenizer = ByteLevel()
"Hello 🌍 👋" → byte-level representation (always works)
```
**Unicode normalization**:
```python
from tokenizers.normalizers import NFKC
tokenizer.normalizer = NFKC()
"é" (composed) ↔ "é" (decomposed) → normalized to one form
```
## Troubleshooting
### Issue: Poor subword splitting
**Symptom**:
```
"running" → ['r', 'u', 'n', 'n', 'i', 'n', 'g'] (too granular)
```
**Solutions**:
1. Increase vocabulary size
2. Train longer (more merge iterations)
3. Lower `min_frequency` threshold
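For example, a retraining sketch with a larger vocabulary and a lower `min_frequency` (numbers are illustrative):
```python
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import Whitespace
tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
tokenizer.pre_tokenizer = Whitespace()
trainer = BpeTrainer(
    vocab_size=50000,  # was 30000: more room for whole-word tokens
    min_frequency=1,   # was 2: keep rarer merges
    special_tokens=["[UNK]", "[PAD]"]
)
tokenizer.train(files=["corpus.txt"], trainer=trainer)
print(tokenizer.encode("running").tokens)  # expect 1-2 pieces, not single characters
```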
### Issue: Too many unknown tokens
**Symptom**:
```
5% of tokens are [UNK]
```
**Solutions**:
1. Increase vocabulary size
2. Use byte-level BPE (no UNK possible)
3. Verify training corpus is representative
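A quick sketch for measuring the unknown-token rate on held-out text (assumes a trained `tokenizers.Tokenizer` with an `[UNK]` token):
```python
def unk_rate(tokenizer, texts, unk_token="[UNK]"):
    unk_id = tokenizer.token_to_id(unk_token)
    total = unk = 0
    for enc in tokenizer.encode_batch(texts):
        total += len(enc.ids)
        unk += sum(1 for i in enc.ids if i == unk_id)
    return unk / max(total, 1)
rate = unk_rate(tokenizer, ["held-out sentence one", "held-out sentence two"])
print(f"UNK rate: {rate:.2%}")  # aim for well under 1%
```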
### Issue: Inconsistent tokenization
**Symptom**:
```
"running" → ['run', 'ning']
"runner" → ['r', 'u', 'n', 'n', 'e', 'r']
```
**Solutions**:
1. Check normalization consistency
2. Ensure pre-tokenization is deterministic
3. Use Unigram for probabilistic variance
## Best practices
1. **Match algorithm to model architecture**:
- BERT-style → WordPiece
- GPT-style → BPE
- T5-style → Unigram
2. **Use byte-level for multilingual**:
- Handles any Unicode
- No unknown tokens
3. **Test on representative data**:
- Measure compression ratio
- Check unknown token rate
- Inspect sample tokenizations
4. **Version control tokenizers**:
- Save with model
- Document special tokens
- Track vocabulary changes

View file

@ -0,0 +1,637 @@
# Transformers Integration
Complete guide to using HuggingFace Tokenizers with the Transformers library.
## AutoTokenizer
The easiest way to load tokenizers.
### Loading pretrained tokenizers
```python
from transformers import AutoTokenizer
# Load from HuggingFace Hub
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
# Check if using fast tokenizer (Rust-based)
print(tokenizer.is_fast) # True
# Access underlying tokenizers.Tokenizer
if tokenizer.is_fast:
fast_tokenizer = tokenizer.backend_tokenizer
print(type(fast_tokenizer)) # <class 'tokenizers.Tokenizer'>
```
### Fast vs slow tokenizers
| Feature | Fast (Rust) | Slow (Python) |
|--------------------------|----------------|---------------|
| Speed | 5-10× faster | Baseline |
| Alignment tracking | ✅ Full support | ❌ Limited |
| Batch processing | ✅ Optimized | ⚠️ Slower |
| Offset mapping | ✅ Yes | ❌ No |
| Installation | `tokenizers` | Built-in |
**Always use fast tokenizers when available.**
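A rough timing sketch to see the difference yourself (absolute numbers depend on hardware):
```python
import time
from transformers import AutoTokenizer
texts = ["This is a reasonably long example sentence."] * 10_000
for use_fast in (True, False):
    tok = AutoTokenizer.from_pretrained("bert-base-uncased", use_fast=use_fast)
    start = time.time()
    tok(texts, truncation=True, max_length=128)
    print(f"use_fast={use_fast}: {time.time() - start:.2f}s")
```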
### Check available tokenizers
```python
from transformers import TOKENIZER_MAPPING
# List all fast tokenizers
for config_class, (slow, fast) in TOKENIZER_MAPPING.items():
if fast is not None:
print(f"{config_class.__name__}: {fast.__name__}")
```
## PreTrainedTokenizerFast
Wrap custom tokenizers for transformers.
### Convert custom tokenizer
```python
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from transformers import PreTrainedTokenizerFast
# Train custom tokenizer
tokenizer = Tokenizer(BPE())
trainer = BpeTrainer(
vocab_size=30000,
special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"]
)
tokenizer.train(files=["corpus.txt"], trainer=trainer)
# Save tokenizer
tokenizer.save("my-tokenizer.json")
# Wrap for transformers
transformers_tokenizer = PreTrainedTokenizerFast(
tokenizer_file="my-tokenizer.json",
unk_token="[UNK]",
sep_token="[SEP]",
pad_token="[PAD]",
cls_token="[CLS]",
mask_token="[MASK]"
)
# Save in transformers format
transformers_tokenizer.save_pretrained("my-tokenizer")
```
**Result**: Directory with `tokenizer.json` + `tokenizer_config.json` + `special_tokens_map.json`
### Use like any transformers tokenizer
```python
# Load
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("my-tokenizer")
# Encode with all transformers features
outputs = tokenizer(
"Hello world",
padding="max_length",
truncation=True,
max_length=128,
return_tensors="pt"
)
print(outputs.keys())
# dict_keys(['input_ids', 'token_type_ids', 'attention_mask'])
```
## Special tokens
### Default special tokens
| Model Family | CLS/BOS | SEP/EOS | PAD | UNK | MASK |
|--------------|---------|---------------|---------|---------|---------|
| BERT | [CLS] | [SEP] | [PAD] | [UNK] | [MASK] |
| GPT-2 | - | <\|endoftext\|> | <\|endoftext\|> | <\|endoftext\|> | - |
| RoBERTa | <s> | </s> | <pad> | <unk> | <mask> |
| T5 | - | </s> | <pad> | <unk> | - |
### Adding special tokens
```python
# Add new special tokens
special_tokens_dict = {
"additional_special_tokens": ["<|image|>", "<|video|>", "<|audio|>"]
}
num_added_tokens = tokenizer.add_special_tokens(special_tokens_dict)
print(f"Added {num_added_tokens} tokens")
# Resize model embeddings
model.resize_token_embeddings(len(tokenizer))
# Use new tokens
text = "This is an image: <|image|>"
tokens = tokenizer.encode(text)
```
### Adding regular tokens
```python
# Add domain-specific tokens
new_tokens = ["COVID-19", "mRNA", "vaccine"]
num_added = tokenizer.add_tokens(new_tokens)
# These are NOT special tokens (can be split if needed)
tokenizer.add_tokens(new_tokens, special_tokens=False)
# These ARE special tokens (never split)
tokenizer.add_tokens(new_tokens, special_tokens=True)
```
## Encoding and decoding
### Basic encoding
```python
# Single sentence
text = "Hello, how are you?"
encoded = tokenizer(text)
print(encoded)
# {'input_ids': [101, 7592, 1010, 2129, 2024, 2017, 1029, 102],
# 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0],
# 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1]}
```
### Batch encoding
```python
# Multiple sentences
texts = ["Hello world", "How are you?", "I am fine"]
encoded = tokenizer(texts, padding="max_length", truncation=True, max_length=10)
print(encoded['input_ids'])
# [[101, 7592, 2088, 102, 0, 0, 0, 0, 0, 0],
# [101, 2129, 2024, 2017, 1029, 102, 0, 0, 0, 0],
# [101, 1045, 2572, 2986, 102, 0, 0, 0, 0, 0]]
```
### Return tensors
```python
# Return PyTorch tensors
outputs = tokenizer("Hello world", return_tensors="pt")
print(outputs['input_ids'].shape) # torch.Size([1, 4])
# Return TensorFlow tensors
outputs = tokenizer("Hello world", return_tensors="tf")
# Return NumPy arrays
outputs = tokenizer("Hello world", return_tensors="np")
# Return lists (default)
outputs = tokenizer("Hello world", return_tensors=None)
```
### Decoding
```python
# Decode token IDs
ids = [101, 7592, 2088, 102]
text = tokenizer.decode(ids)
print(text) # "[CLS] hello world [SEP]"
# Skip special tokens
text = tokenizer.decode(ids, skip_special_tokens=True)
print(text) # "hello world"
# Batch decode
batch_ids = [[101, 7592, 102], [101, 2088, 102]]
texts = tokenizer.batch_decode(batch_ids, skip_special_tokens=True)
print(texts) # ["hello", "world"]
```
## Padding and truncation
### Padding strategies
```python
# Pad to max length in batch
tokenizer(texts, padding="longest")
# Pad to model max length
tokenizer(texts, padding="max_length", max_length=128)
# No padding
tokenizer(texts, padding=False)
# Pad to multiple of value (for efficient computation)
tokenizer(texts, padding="max_length", max_length=128, pad_to_multiple_of=8)
# Result: length will be 128 (already multiple of 8)
```
### Truncation strategies
```python
# Truncate to max length
tokenizer(text, truncation=True, max_length=10)
# Only truncate first sequence (for pairs)
tokenizer(text1, text2, truncation="only_first", max_length=20)
# Only truncate second sequence
tokenizer(text1, text2, truncation="only_second", max_length=20)
# Truncate longest first (default for pairs)
tokenizer(text1, text2, truncation="longest_first", max_length=20)
# No truncation (error if too long)
tokenizer(text, truncation=False)
```
### Stride for long documents
```python
# For documents longer than max_length
text = "Very long document " * 1000
# Encode with overlap
encodings = tokenizer(
text,
max_length=512,
stride=128, # Overlap between chunks
truncation=True,
return_overflowing_tokens=True,
return_offsets_mapping=True
)
# Get all chunks
num_chunks = len(encodings['input_ids'])
print(f"Split into {num_chunks} chunks")
# Each chunk overlaps by stride tokens
for i, chunk in enumerate(encodings['input_ids']):
print(f"Chunk {i}: {len(chunk)} tokens")
```
**Use case**: Long document QA, sliding window inference
## Alignment and offsets
### Offset mapping
```python
# Get character offsets for each token
encoded = tokenizer("Hello, world!", return_offsets_mapping=True)
for token, (start, end) in zip(
encoded.tokens(),
    encoded['offset_mapping']
):
print(f"{token:10s} → [{start:2d}, {end:2d})")
# Output:
# [CLS] → [ 0, 0)
# Hello → [ 0, 5)
# , → [ 5, 6)
# world → [ 7, 12)
# ! → [12, 13)
# [SEP] → [ 0, 0)
```
### Word IDs
```python
# Get word index for each token
encoded = tokenizer("Hello world", return_offsets_mapping=True)
word_ids = encoded.word_ids()
print(word_ids)
# [None, 0, 1, None]
# None = special token, 0 = first word, 1 = second word
```
**Use case**: Token classification (NER, POS tagging)
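A common pattern is aligning word-level labels to subword tokens via `word_ids()` (a minimal sketch; the words and labels are made-up examples):
```python
words = ["Paris", "is", "beautiful"]
word_labels = [1, 0, 0]  # hypothetical per-word labels (1 = location)
encoded = tokenizer(words, is_split_into_words=True)
aligned = []
for word_id in encoded.word_ids():
    if word_id is None:
        aligned.append(-100)  # special tokens: ignored by the loss
    else:
        aligned.append(word_labels[word_id])
print(aligned)  # e.g. [-100, 1, 0, 0, -100]
```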
### Character to token mapping
```python
text = "Machine learning is awesome"
encoded = tokenizer(text, return_offsets_mapping=True)
# Find token for character position
char_pos = 8 # "l" in "learning"
token_idx = encoded.char_to_token(char_pos)
print(f"Character {char_pos} is in token {token_idx}: {encoded.tokens()[token_idx]}")
# Character 8 is in token 2: learning
```
**Use case**: Question answering (map answer character span to tokens)
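For QA, the same mapping turns a character-level answer annotation into a token span (sketch; the answer span is a made-up annotation):
```python
context = "Machine learning is awesome"
answer_text = "Machine learning"                  # hypothetical gold answer
answer_start = context.index(answer_text)
answer_end = answer_start + len(answer_text) - 1  # last character of the answer
encoded = tokenizer(context, return_offsets_mapping=True)
start_token = encoded.char_to_token(answer_start)
end_token = encoded.char_to_token(answer_end)
print(start_token, end_token)  # token indices bracketing the answer span
```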
### Sequence pairs
```python
# Encode sentence pair
encoded = tokenizer("Question here", "Answer here", return_offsets_mapping=True)
# Get sequence IDs (which sequence each token belongs to)
sequence_ids = encoded.sequence_ids()
print(sequence_ids)
# [None, 0, 0, 0, None, 1, 1, 1, None]
# None = special token, 0 = question, 1 = answer
```
## Model integration
### Use with transformers models
```python
from transformers import AutoModel, AutoTokenizer
import torch
# Load model and tokenizer
model = AutoModel.from_pretrained("bert-base-uncased")
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
# Tokenize
text = "Hello world"
inputs = tokenizer(text, return_tensors="pt")
# Forward pass
with torch.no_grad():
outputs = model(**inputs)
# Get embeddings
last_hidden_state = outputs.last_hidden_state
print(last_hidden_state.shape) # [1, seq_len, hidden_size]
```
### Custom model with custom tokenizer
```python
from transformers import BertConfig, BertModel
# Train custom tokenizer
from tokenizers import Tokenizer, models, trainers
tokenizer = Tokenizer(models.BPE())
trainer = trainers.BpeTrainer(vocab_size=30000)
tokenizer.train(files=["data.txt"], trainer=trainer)
# Wrap for transformers
from transformers import PreTrainedTokenizerFast
fast_tokenizer = PreTrainedTokenizerFast(
tokenizer_object=tokenizer,
unk_token="[UNK]",
pad_token="[PAD]"
)
# Create model with custom vocab size
config = BertConfig(vocab_size=30000)
model = BertModel(config)
# Use together
inputs = fast_tokenizer("Hello world", return_tensors="pt")
outputs = model(**inputs)
```
### Save and load together
```python
# Save both
model.save_pretrained("my-model")
tokenizer.save_pretrained("my-model")
# Directory structure:
# my-model/
# ├── config.json
# ├── pytorch_model.bin
# ├── tokenizer.json
# ├── tokenizer_config.json
# └── special_tokens_map.json
# Load both
from transformers import AutoModel, AutoTokenizer
model = AutoModel.from_pretrained("my-model")
tokenizer = AutoTokenizer.from_pretrained("my-model")
```
## Advanced features
### Multimodal tokenization
```python
from transformers import AutoTokenizer
# LLaVA-style (image + text)
tokenizer = AutoTokenizer.from_pretrained("llava-hf/llava-1.5-7b-hf")
# Add image placeholder token
tokenizer.add_special_tokens({"additional_special_tokens": ["<image>"]})
# Use in prompt
text = "Describe this image: <image>"
inputs = tokenizer(text, return_tensors="pt")
```
### Template formatting
```python
# Chat template
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hello!"},
{"role": "assistant", "content": "Hi! How can I help?"},
{"role": "user", "content": "What's the weather?"}
]
# Apply chat template (if tokenizer has one)
if hasattr(tokenizer, "apply_chat_template"):
text = tokenizer.apply_chat_template(messages, tokenize=False)
inputs = tokenizer(text, return_tensors="pt")
```
### Custom template
```python
from transformers import PreTrainedTokenizerFast
tokenizer = PreTrainedTokenizerFast(tokenizer_file="tokenizer.json")
# Define chat template
tokenizer.chat_template = """
{%- for message in messages %}
{%- if message['role'] == 'system' %}
System: {{ message['content'] }}\\n
{%- elif message['role'] == 'user' %}
User: {{ message['content'] }}\\n
{%- elif message['role'] == 'assistant' %}
Assistant: {{ message['content'] }}\\n
{%- endif %}
{%- endfor %}
Assistant:
"""
# Use template
text = tokenizer.apply_chat_template(messages, tokenize=False)
```
## Performance optimization
### Batch processing
```python
# Process large datasets efficiently
from datasets import load_dataset
dataset = load_dataset("imdb", split="train[:1000]")
# Tokenize in batches
def tokenize_function(examples):
return tokenizer(
examples["text"],
padding="max_length",
truncation=True,
max_length=512
)
# Map over dataset (batched)
tokenized_dataset = dataset.map(
tokenize_function,
batched=True,
batch_size=1000,
num_proc=4 # Parallel processing
)
```
### Caching
```python
# Enable caching for repeated tokenization
tokenizer = AutoTokenizer.from_pretrained(
"bert-base-uncased",
use_fast=True,
cache_dir="./cache" # Cache tokenizer files
)
# Tokenize with caching
from functools import lru_cache
@lru_cache(maxsize=10000)
def cached_tokenize(text):
return tuple(tokenizer.encode(text))
# Reuses cached results for repeated inputs
```
### Memory efficiency
```python
# For very large datasets, use streaming
from datasets import load_dataset
dataset = load_dataset("pile", split="train", streaming=True)
def process_batch(batch):
# Tokenize
tokens = tokenizer(batch["text"], truncation=True, max_length=512)
# Process tokens...
return tokens
# Process in chunks (memory efficient)
for batch in dataset.batch(batch_size=1000):
processed = process_batch(batch)
```
## Troubleshooting
### Issue: Tokenizer not fast
**Symptom**:
```python
tokenizer.is_fast # False
```
**Solution**: Install tokenizers library
```bash
pip install tokenizers
```
### Issue: Special tokens not working
**Symptom**: Special tokens are split into subwords
**Solution**: Add as special tokens, not regular tokens
```python
# Wrong
tokenizer.add_tokens(["<|image|>"])
# Correct
tokenizer.add_special_tokens({"additional_special_tokens": ["<|image|>"]})
```
### Issue: Offset mapping not available
**Symptom**:
```python
tokenizer("text", return_offsets_mapping=True)
# Error: return_offsets_mapping not supported
```
**Solution**: Use fast tokenizer
```python
from transformers import AutoTokenizer
# Load fast version
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", use_fast=True)
```
### Issue: Padding inconsistent
**Symptom**: Some sequences padded, others not
**Solution**: Specify padding strategy
```python
# Explicit padding
tokenizer(
texts,
padding="max_length", # or "longest"
max_length=128
)
```
## Best practices
1. **Always use fast tokenizers**:
- 5-10× faster
- Full alignment tracking
- Better batch processing
2. **Save tokenizer with model**:
- Ensures reproducibility
- Prevents version mismatches
3. **Use batch processing for datasets**:
- Tokenize with `.map(batched=True)`
- Set `num_proc` for parallelism
4. **Enable caching for repeated inputs**:
- Use `lru_cache` for inference
- Cache tokenizer files with `cache_dir`
5. **Handle special tokens properly**:
- Use `add_special_tokens()` for never-split tokens
- Resize embeddings after adding tokens
6. **Test alignment for downstream tasks**:
- Verify `offset_mapping` is correct
- Test `char_to_token()` on samples
7. **Version control tokenizer config**:
- Save `tokenizer_config.json`
- Document custom templates
- Track vocabulary changes

View file

@ -0,0 +1,723 @@
# Tokenization Pipeline Components
Complete guide to normalizers, pre-tokenizers, models, post-processors, and decoders.
## Pipeline overview
**Full tokenization pipeline**:
```
Raw Text
Normalization (cleaning, lowercasing)
Pre-tokenization (split into words)
Model (apply BPE/WordPiece/Unigram)
Post-processing (add special tokens)
Token IDs
```
**Decoding reverses the process**:
```
Token IDs
Decoder (handle special encodings)
Raw Text
```
## Normalizers
Clean and standardize input text.
### Common normalizers
**Lowercase**:
```python
from tokenizers.normalizers import Lowercase
tokenizer.normalizer = Lowercase()
# Input: "Hello WORLD"
# Output: "hello world"
```
**Unicode normalization**:
```python
from tokenizers.normalizers import NFD, NFC, NFKD, NFKC
# NFD: Canonical decomposition
tokenizer.normalizer = NFD()
# "é" → "e" + "́" (separate characters)
# NFC: Canonical composition (default)
tokenizer.normalizer = NFC()
# "e" + "́" → "é" (composed)
# NFKD: Compatibility decomposition
tokenizer.normalizer = NFKD()
# "fi" → "f" + "i"
# NFKC: Compatibility composition
tokenizer.normalizer = NFKC()
# Most aggressive normalization
```
**Strip accents**:
```python
from tokenizers.normalizers import StripAccents
tokenizer.normalizer = StripAccents()
# Input: "café"
# Output: "cafe"
```
**Whitespace handling**:
```python
from tokenizers.normalizers import Strip, StripAccents
# Remove leading/trailing whitespace
tokenizer.normalizer = Strip()
# Input: " hello "
# Output: "hello"
```
**Replace patterns**:
```python
from tokenizers.normalizers import Replace
# Replace newlines with spaces
tokenizer.normalizer = Replace("\\n", " ")
# Input: "hello\\nworld"
# Output: "hello world"
```
### Combining normalizers
```python
from tokenizers.normalizers import Sequence, NFD, Lowercase, StripAccents
# BERT-style normalization
tokenizer.normalizer = Sequence([
NFD(), # Unicode decomposition
Lowercase(), # Convert to lowercase
StripAccents() # Remove accents
])
# Input: "Café au Lait"
# After NFD: "Café au Lait" (e + ́)
# After Lowercase: "café au lait"
# After StripAccents: "cafe au lait"
```
### Use case examples
**Case-insensitive model (BERT)**:
```python
from tokenizers.normalizers import BertNormalizer
# All-in-one BERT normalization
tokenizer.normalizer = BertNormalizer(
clean_text=True, # Remove control characters
handle_chinese_chars=True, # Add spaces around Chinese
strip_accents=True, # Remove accents
lowercase=True # Lowercase
)
```
**Case-sensitive model (GPT-2)**:
```python
# Minimal normalization
tokenizer.normalizer = NFC() # Only normalize Unicode
```
**Multilingual (mBERT)**:
```python
# Preserve scripts, normalize form
tokenizer.normalizer = NFKC()
```
## Pre-tokenizers
Split text into word-like units before tokenization.
### Whitespace splitting
```python
from tokenizers.pre_tokenizers import Whitespace
tokenizer.pre_tokenizer = Whitespace()
# Input: "Hello world! How are you?"
# Output: [("Hello", (0, 5)), ("world!", (6, 12)), ("How", (13, 16)), ("are", (17, 20)), ("you?", (21, 25))]
```
### Punctuation isolation
```python
from tokenizers.pre_tokenizers import Punctuation
tokenizer.pre_tokenizer = Punctuation()
# Input: "Hello, world!"
# Output: [("Hello", ...), (",", ...), ("world", ...), ("!", ...)]
```
### Byte-level (GPT-2)
```python
from tokenizers.pre_tokenizers import ByteLevel
tokenizer.pre_tokenizer = ByteLevel(add_prefix_space=True)
# Input: "Hello world"
# Output: Byte-level tokens with Ġ prefix for spaces
# [("ĠHello", ...), ("Ġworld", ...)]
```
**Key feature**: Handles ALL Unicode characters (256 byte combinations)
### Metaspace (SentencePiece)
```python
from tokenizers.pre_tokenizers import Metaspace
tokenizer.pre_tokenizer = Metaspace(replacement="▁", add_prefix_space=True)
# Input: "Hello world"
# Output: [("▁Hello", ...), ("▁world", ...)]
```
**Used by**: T5, ALBERT (via SentencePiece)
### Digits splitting
```python
from tokenizers.pre_tokenizers import Digits
# Split digits individually
tokenizer.pre_tokenizer = Digits(individual_digits=True)
# Input: "Room 123"
# Output: [("Room", ...), ("1", ...), ("2", ...), ("3", ...)]
# Keep digits together
tokenizer.pre_tokenizer = Digits(individual_digits=False)
# Input: "Room 123"
# Output: [("Room", ...), ("123", ...)]
```
### BERT pre-tokenizer
```python
from tokenizers.pre_tokenizers import BertPreTokenizer
tokenizer.pre_tokenizer = BertPreTokenizer()
# Splits on whitespace and punctuation, preserves CJK
# Input: "Hello, 世界!"
# Output: [("Hello", ...), (",", ...), ("世", ...), ("界", ...), ("!", ...)]
```
### Combining pre-tokenizers
```python
from tokenizers.pre_tokenizers import Sequence, WhitespaceSplit, Punctuation
tokenizer.pre_tokenizer = Sequence([
    WhitespaceSplit(), # Split on whitespace only
    Punctuation()      # Then isolate punctuation
])
# Input: "Hello, world!"
# After WhitespaceSplit: [("Hello,", ...), ("world!", ...)]
# After Punctuation: [("Hello", ...), (",", ...), ("world", ...), ("!", ...)]
```
### Pre-tokenizer comparison
| Pre-tokenizer | Use Case | Example |
|-------------------|---------------------------------|--------------------------------------------|
| Whitespace | Simple English | "Hello world" → ["Hello", "world"] |
| Punctuation | Isolate symbols | "world!" → ["world", "!"] |
| ByteLevel | Multilingual, emojis | "🌍" → byte tokens |
| Metaspace | SentencePiece-style | "Hello" → ["▁Hello"] |
| BertPreTokenizer | BERT-style (CJK aware) | "世界" → ["世", "界"] |
| Digits | Handle numbers | "123" → ["1", "2", "3"] or ["123"] |
## Models
Core tokenization algorithms.
### BPE Model
```python
from tokenizers.models import BPE
model = BPE(
vocab=None, # Or provide pre-built vocab
merges=None, # Or provide merge rules
unk_token="[UNK]", # Unknown token
continuing_subword_prefix="",
end_of_word_suffix="",
fuse_unk=False # Keep unknown tokens separate
)
tokenizer = Tokenizer(model)
```
**Parameters**:
- `vocab`: Dict of token → id
- `merges`: List of merge rules `["a b", "ab c"]`
- `unk_token`: Token for unknown words
- `continuing_subword_prefix`: Prefix for subwords (empty for GPT-2)
- `end_of_word_suffix`: Suffix for last subword (empty for GPT-2)
### WordPiece Model
```python
from tokenizers.models import WordPiece
model = WordPiece(
vocab=None,
unk_token="[UNK]",
max_input_chars_per_word=100, # Max word length
continuing_subword_prefix="##" # BERT-style prefix
)
tokenizer = Tokenizer(model)
```
**Key difference**: Uses `##` prefix for continuing subwords.
### Unigram Model
```python
from tokenizers.models import Unigram
model = Unigram(
vocab=None, # List of (token, score) tuples
unk_id=0, # ID for unknown token
byte_fallback=False # Fall back to bytes if no match
)
tokenizer = Tokenizer(model)
```
**Probabilistic**: Selects tokenization with highest probability.
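The vocabulary is supplied as `(token, log-probability)` pairs — a toy sketch with made-up scores, just to show the format:

```python
from tokenizers import Tokenizer
from tokenizers.models import Unigram

# Toy vocabulary: (token, log probability) pairs; the scores are illustrative
vocab = [
    ("<unk>", 0.0),
    ("▁hello", -2.1),
    ("▁world", -2.3),
    ("▁", -1.5),
]
tokenizer = Tokenizer(Unigram(vocab, unk_id=0, byte_fallback=False))
```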
### WordLevel Model
```python
from tokenizers.models import WordLevel
# Simple word-to-ID mapping (no subwords)
model = WordLevel(
vocab=None,
unk_token="[UNK]"
)
tokenizer = Tokenizer(model)
```
**Warning**: Requires huge vocabulary (one token per word).
## Post-processors
Add special tokens and format output.
### Template processing
**BERT-style** (`[CLS] sentence [SEP]`):
```python
from tokenizers.processors import TemplateProcessing
tokenizer.post_processor = TemplateProcessing(
single="[CLS] $A [SEP]",
pair="[CLS] $A [SEP] $B [SEP]",
special_tokens=[
("[CLS]", 101),
("[SEP]", 102),
],
)
# Single sentence
output = tokenizer.encode("Hello world")
# [101, ..., 102] ([CLS] hello world [SEP])
# Sentence pair
output = tokenizer.encode("Hello", "world")
# [101, ..., 102, ..., 102] ([CLS] hello [SEP] world [SEP])
```
**GPT-2 style** (`sentence <|endoftext|>`):
```python
tokenizer.post_processor = TemplateProcessing(
single="$A <|endoftext|>",
special_tokens=[
("<|endoftext|>", 50256),
],
)
```
**RoBERTa style** (`<s> sentence </s>`):
```python
tokenizer.post_processor = TemplateProcessing(
single="<s> $A </s>",
pair="<s> $A </s> </s> $B </s>",
special_tokens=[
("<s>", 0),
("</s>", 2),
],
)
```
**T5 style** (appends `</s>` only, no `[CLS]`/`[SEP]`):
```python
tokenizer.post_processor = TemplateProcessing(
    single="$A </s>",
    pair="$A </s> $B </s>",
    special_tokens=[("</s>", 1)],  # </s> has id 1 in the T5 vocabulary
)
```
### RobertaProcessing
```python
from tokenizers.processors import RobertaProcessing
tokenizer.post_processor = RobertaProcessing(
sep=("</s>", 2),
cls=("<s>", 0),
add_prefix_space=True, # Add space before first token
trim_offsets=True # Trim leading space from offsets
)
```
### ByteLevelProcessing
```python
from tokenizers.processors import ByteLevel as ByteLevelProcessing
tokenizer.post_processor = ByteLevelProcessing(
trim_offsets=True # Remove Ġ from offsets
)
```
## Decoders
Convert token IDs back to text.
### ByteLevel decoder
```python
from tokenizers.decoders import ByteLevel
tokenizer.decoder = ByteLevel()
# Handles byte-level tokens
# ["ĠHello", "Ġworld"] → "Hello world"
```
### WordPiece decoder
```python
from tokenizers.decoders import WordPiece
tokenizer.decoder = WordPiece(prefix="##")
# Removes ## prefix and concatenates
# ["token", "##ization"] → "tokenization"
```
### Metaspace decoder
```python
from tokenizers.decoders import Metaspace
tokenizer.decoder = Metaspace(replacement="▁", add_prefix_space=True)
# Converts ▁ back to spaces
# ["▁Hello", "▁world"] → "Hello world"
```
### BPEDecoder
```python
from tokenizers.decoders import BPEDecoder
tokenizer.decoder = BPEDecoder(suffix="</w>")
# Removes suffix and concatenates
# ["token", "ization</w>"] → "tokenization"
```
### Sequence decoder
```python
from tokenizers.decoders import Sequence, ByteLevel, Strip
tokenizer.decoder = Sequence([
ByteLevel(), # Decode byte-level first
Strip(' ', 1, 1) # Strip leading/trailing spaces
])
```
## Complete pipeline examples
### BERT tokenizer
```python
from tokenizers import Tokenizer
from tokenizers.models import WordPiece
from tokenizers.normalizers import BertNormalizer
from tokenizers.pre_tokenizers import BertPreTokenizer
from tokenizers.processors import TemplateProcessing
from tokenizers.decoders import WordPiece as WordPieceDecoder
# Model
tokenizer = Tokenizer(WordPiece(unk_token="[UNK]"))
# Normalization
tokenizer.normalizer = BertNormalizer(lowercase=True)
# Pre-tokenization
tokenizer.pre_tokenizer = BertPreTokenizer()
# Post-processing
tokenizer.post_processor = TemplateProcessing(
single="[CLS] $A [SEP]",
pair="[CLS] $A [SEP] $B [SEP]",
special_tokens=[("[CLS]", 101), ("[SEP]", 102)],
)
# Decoder
tokenizer.decoder = WordPieceDecoder(prefix="##")
# Enable padding
tokenizer.enable_padding(pad_id=0, pad_token="[PAD]")
# Enable truncation
tokenizer.enable_truncation(max_length=512)
```
### GPT-2 tokenizer
```python
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.normalizers import NFC
from tokenizers.pre_tokenizers import ByteLevel
from tokenizers.decoders import ByteLevel as ByteLevelDecoder
from tokenizers.processors import TemplateProcessing
# Model
tokenizer = Tokenizer(BPE())
# Normalization (minimal)
tokenizer.normalizer = NFC()
# Byte-level pre-tokenization
tokenizer.pre_tokenizer = ByteLevel(add_prefix_space=False)
# Post-processing
tokenizer.post_processor = TemplateProcessing(
single="$A <|endoftext|>",
special_tokens=[("<|endoftext|>", 50256)],
)
# Byte-level decoder
tokenizer.decoder = ByteLevelDecoder()
```
### T5 tokenizer (SentencePiece-style)
```python
from tokenizers import Tokenizer
from tokenizers.models import Unigram
from tokenizers.normalizers import NFKC
from tokenizers.pre_tokenizers import Metaspace
from tokenizers.decoders import Metaspace as MetaspaceDecoder
# Model
tokenizer = Tokenizer(Unigram())
# Normalization
tokenizer.normalizer = NFKC()
# Metaspace pre-tokenization
tokenizer.pre_tokenizer = Metaspace(replacement="▁", add_prefix_space=True)
# Post-processing omitted here; a trained T5 tokenizer appends </s> via TemplateProcessing
tokenizer.post_processor = None
# Metaspace decoder
tokenizer.decoder = MetaspaceDecoder(replacement="▁", add_prefix_space=True)
```
## Alignment tracking
Track token positions in original text.
### Basic alignment
```python
text = "Hello, world!"
output = tokenizer.encode(text)
for token, (start, end) in zip(output.tokens, output.offsets):
print(f"{token:10s} → [{start:2d}, {end:2d}): {text[start:end]!r}")
# Output:
# [CLS] → [ 0, 0): ''
# hello → [ 0, 5): 'Hello'
# , → [ 5, 6): ','
# world → [ 7, 12): 'world'
# ! → [12, 13): '!'
# [SEP] → [ 0, 0): ''
```
### Word-level alignment
```python
# Get word_ids (which word each token belongs to)
encoding = tokenizer.encode("Hello world")
word_ids = encoding.word_ids
print(word_ids)
# [None, 0, 0, 1, None]
# None = special token, 0 = first word, 1 = second word
```
**Use case**: Token classification (NER)
```python
# Align predictions to words
predictions = ["O", "B-PER", "I-PER", "O", "O"]
word_predictions = {}
for token_idx, word_idx in enumerate(encoding.word_ids):
if word_idx is not None and word_idx not in word_predictions:
word_predictions[word_idx] = predictions[token_idx]
print(word_predictions)
# {0: "B-PER", 1: "O"} # First word is PERSON, second is OTHER
```
### Span alignment
```python
# Find token span for character span
text = "Machine learning is awesome"
char_start, char_end = 8, 16 # "learning"
encoding = tokenizer.encode(text)
# Find token span
token_start = encoding.char_to_token(char_start)
token_end = encoding.char_to_token(char_end - 1) + 1
print(f"Tokens {token_start}:{token_end} = {encoding.tokens[token_start:token_end]}")
# Tokens 2:3 = ['learning']
```
**Use case**: Question answering (extract answer span)
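For extractive QA the reverse mapping matters too: given a predicted token span, offsets recover the answer text from the context. A minimal sketch, assuming `start_token`/`end_token` come from your model:

```python
text = "Machine learning is awesome"
encoding = tokenizer.encode(text)

# Hypothetical model prediction: answer span as inclusive token indices
start_token, end_token = 2, 2

char_start = encoding.offsets[start_token][0]
char_end = encoding.offsets[end_token][1]
print(text[char_start:char_end])  # "learning"
```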
## Custom components
### Custom normalizer
```python
from tokenizers import NormalizedString
from tokenizers.normalizers import Normalizer

class CustomNormalizer:
    def normalize(self, normalized: NormalizedString):
        # Custom normalization logic
        normalized.lowercase()
        normalized.replace("  ", " ")  # Collapse double spaces

# Custom Python components must be wrapped with Normalizer.custom
tokenizer.normalizer = Normalizer.custom(CustomNormalizer())
```
### Custom pre-tokenizer
```python
from tokenizers import NormalizedString, PreTokenizedString
from tokenizers.pre_tokenizers import PreTokenizer

class CustomPreTokenizer:
    def pre_tokenize(self, pretok: PreTokenizedString):
        # Custom pre-tokenization logic: split each piece on spaces
        pretok.split(lambda i, normalized: normalized.split(" ", "removed"))

# Custom Python components must be wrapped with PreTokenizer.custom
tokenizer.pre_tokenizer = PreTokenizer.custom(CustomPreTokenizer())
```
## Troubleshooting
### Issue: Misaligned offsets
**Symptom**: Offsets don't map back onto the original text
```python
text = "  hello"  # Leading spaces
# Token "hello" reports offsets (0, 5), which slices back to "  hel" instead of "hello"
```
**Solution**: Don't strip characters in the normalizer; trim offsets in the post-processor instead
```python
# Avoid: stripping in the normalizer shifts every offset
tokenizer.normalizer = Sequence([
    Strip(),  # This changes offsets!
])
# Prefer: trim offsets in the post-processor
tokenizer.post_processor = ByteLevelProcessing(trim_offsets=True)
```
### Issue: Special tokens not added
**Symptom**: No [CLS] or [SEP] in output
**Solution**: Check post-processor is set
```python
tokenizer.post_processor = TemplateProcessing(
single="[CLS] $A [SEP]",
special_tokens=[("[CLS]", 101), ("[SEP]", 102)],
)
```
### Issue: Incorrect decoding
**Symptom**: Decoded text has ## or ▁
**Solution**: Set correct decoder
```python
# For WordPiece
tokenizer.decoder = WordPieceDecoder(prefix="##")
# For SentencePiece
tokenizer.decoder = MetaspaceDecoder(replacement="▁")
```
## Best practices
1. **Match pipeline to model architecture**:
- BERT → BertNormalizer + BertPreTokenizer + WordPiece
- GPT-2 → NFC + ByteLevel + BPE
- T5 → NFKC + Metaspace + Unigram
2. **Test pipeline on sample inputs** (see the round-trip sketch after this list):
- Check normalization doesn't over-normalize
- Verify pre-tokenization splits correctly
- Ensure decoding reconstructs text
3. **Preserve alignment for downstream tasks**:
- Use `trim_offsets` instead of stripping in normalizer
- Test `char_to_token()` on sample spans
4. **Document your pipeline**:
- Save complete tokenizer config
- Document special tokens
- Note any custom components
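A quick round-trip check covering points 2 and 3 above — a small sketch using the pipeline built earlier in this guide:

```python
# Encode, decode, and inspect a few sample inputs
samples = ["Hello, world!", "Tokenization handles UPPER case.", "  leading spaces  "]

for text in samples:
    encoding = tokenizer.encode(text)
    decoded = tokenizer.decode(encoding.ids)
    print(f"text:    {text!r}")
    print(f"tokens:  {encoding.tokens}")
    print(f"decoded: {decoded!r}\n")
```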

View file

@ -0,0 +1,565 @@
# Training Custom Tokenizers
Complete guide to training tokenizers from scratch.
## Training workflow
### Step 1: Choose tokenization algorithm
**Decision tree**:
- **GPT-style model** → BPE
- **BERT-style model** → WordPiece
- **Multilingual/No word boundaries** → Unigram
### Step 2: Prepare training data
```python
# Option 1: From files
files = ["train.txt", "validation.txt"]
# Option 2: From Python list
texts = [
"This is the first sentence.",
"This is the second sentence.",
# ... more texts
]
# Option 3: From dataset iterator
from datasets import load_dataset
dataset = load_dataset("wikitext", "wikitext-103-raw-v1", split="train")
def batch_iterator(batch_size=1000):
for i in range(0, len(dataset), batch_size):
yield dataset[i:i + batch_size]["text"]
```
### Step 3: Initialize tokenizer
**BPE example**:
```python
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import ByteLevel
from tokenizers.decoders import ByteLevel as ByteLevelDecoder
tokenizer = Tokenizer(BPE())
tokenizer.pre_tokenizer = ByteLevel()
tokenizer.decoder = ByteLevelDecoder()
trainer = BpeTrainer(
vocab_size=50000,
min_frequency=2,
special_tokens=["<|endoftext|>", "<|padding|>"],
show_progress=True
)
```
**WordPiece example**:
```python
from tokenizers.models import WordPiece
from tokenizers.trainers import WordPieceTrainer
from tokenizers.normalizers import BertNormalizer
from tokenizers.pre_tokenizers import BertPreTokenizer
tokenizer = Tokenizer(WordPiece(unk_token="[UNK]"))
tokenizer.normalizer = BertNormalizer(lowercase=True)
tokenizer.pre_tokenizer = BertPreTokenizer()
trainer = WordPieceTrainer(
vocab_size=30522,
min_frequency=2,
special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"],
continuing_subword_prefix="##",
show_progress=True
)
```
**Unigram example**:
```python
from tokenizers.models import Unigram
from tokenizers.trainers import UnigramTrainer
tokenizer = Tokenizer(Unigram())
trainer = UnigramTrainer(
vocab_size=8000,
special_tokens=["<unk>", "<s>", "</s>", "<pad>"],
unk_token="<unk>",
show_progress=True
)
```
### Step 4: Train
```python
# From files
tokenizer.train(files=files, trainer=trainer)
# From iterator (recommended for large datasets)
tokenizer.train_from_iterator(
batch_iterator(),
trainer=trainer,
length=len(dataset) # Optional, for progress bar
)
```
**Training time** (30k vocab on 16-core CPU):
- 10 MB: 15-30 seconds
- 100 MB: 1-3 minutes
- 1 GB: 15-30 minutes
- 10 GB: 2-4 hours
### Step 5: Add post-processing
```python
from tokenizers.processors import TemplateProcessing
# BERT-style
tokenizer.post_processor = TemplateProcessing(
single="[CLS] $A [SEP]",
pair="[CLS] $A [SEP] $B [SEP]",
special_tokens=[
("[CLS]", tokenizer.token_to_id("[CLS]")),
("[SEP]", tokenizer.token_to_id("[SEP]")),
],
)
# GPT-2 style
tokenizer.post_processor = TemplateProcessing(
single="$A <|endoftext|>",
special_tokens=[
("<|endoftext|>", tokenizer.token_to_id("<|endoftext|>")),
],
)
```
### Step 6: Save
```python
# Save to JSON
tokenizer.save("my-tokenizer.json")
# Save to directory (for transformers)
tokenizer.save("my-tokenizer-dir/tokenizer.json")
# Convert to transformers format
from transformers import PreTrainedTokenizerFast
transformers_tokenizer = PreTrainedTokenizerFast(
tokenizer_object=tokenizer,
unk_token="[UNK]",
pad_token="[PAD]",
cls_token="[CLS]",
sep_token="[SEP]",
mask_token="[MASK]"
)
transformers_tokenizer.save_pretrained("my-tokenizer-dir")
```
## Trainer configuration
### BpeTrainer parameters
```python
from tokenizers.trainers import BpeTrainer
trainer = BpeTrainer(
vocab_size=30000, # Target vocabulary size
min_frequency=2, # Minimum frequency for merges
special_tokens=["[UNK]"], # Special tokens (added first)
limit_alphabet=1000, # Limit initial alphabet size
initial_alphabet=[], # Pre-defined initial characters
show_progress=True, # Show progress bar
continuing_subword_prefix="", # Prefix for continuing subwords
end_of_word_suffix="" # Suffix for end of words
)
```
**Parameter tuning**:
- **vocab_size**: Start with 30k for English, 50k for multilingual
- **min_frequency**: 2-5 for large corpora, 1 for small
- **limit_alphabet**: Reduce for non-English (CJK languages)
### WordPieceTrainer parameters
```python
from tokenizers.trainers import WordPieceTrainer
trainer = WordPieceTrainer(
vocab_size=30522, # BERT uses 30,522
min_frequency=2,
special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"],
limit_alphabet=1000,
continuing_subword_prefix="##", # BERT-style prefix
show_progress=True
)
```
### UnigramTrainer parameters
```python
from tokenizers.trainers import UnigramTrainer
trainer = UnigramTrainer(
vocab_size=8000, # Typically smaller than BPE/WordPiece
special_tokens=["<unk>", "<s>", "</s>"],
unk_token="<unk>",
max_piece_length=16, # Maximum token length
n_sub_iterations=2, # EM algorithm iterations
shrinking_factor=0.75, # Vocabulary reduction rate
show_progress=True
)
```
## Training from large datasets
### Memory-efficient training
```python
from datasets import load_dataset
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
# Load dataset
dataset = load_dataset("wikipedia", "20220301.en", split="train", streaming=True)
# Create iterator (yields batches)
def batch_iterator(batch_size=1000):
batch = []
for sample in dataset:
batch.append(sample["text"])
if len(batch) >= batch_size:
yield batch
batch = []
if batch:
yield batch
# Initialize tokenizer
tokenizer = Tokenizer(BPE())
trainer = BpeTrainer(vocab_size=50000, special_tokens=["<|endoftext|>"])
# Train (memory efficient - streams data)
tokenizer.train_from_iterator(
batch_iterator(),
trainer=trainer
)
```
**Memory usage**: ~200 MB (vs 10+ GB loading full dataset)
### Multi-file training
```python
import glob
# Find all training files
files = glob.glob("data/train/*.txt")
print(f"Training on {len(files)} files")
# Train on all files
tokenizer.train(files=files, trainer=trainer)
```
### Parallel training (multi-processing)
```python
from multiprocessing import Pool, cpu_count
import os
def train_shard(shard_files):
"""Train tokenizer on a shard of files."""
tokenizer = Tokenizer(BPE())
trainer = BpeTrainer(vocab_size=50000)
tokenizer.train(files=shard_files, trainer=trainer)
return tokenizer.get_vocab()
# Split files into shards
num_shards = cpu_count()
file_shards = [files[i::num_shards] for i in range(num_shards)]
# Train shards in parallel
with Pool(num_shards) as pool:
vocab_shards = pool.map(train_shard, file_shards)
# Merge vocabularies (custom logic needed)
# This is a simplified example - real implementation would merge intelligently
final_vocab = {}
for vocab in vocab_shards:
final_vocab.update(vocab)
```
## Domain-specific tokenizers
### Code tokenizer
```python
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import ByteLevel
from tokenizers.normalizers import Sequence, NFC
# Code-optimized configuration
tokenizer = Tokenizer(BPE())
# Minimal normalization (preserve case, whitespace)
tokenizer.normalizer = NFC() # Only normalize Unicode
# Byte-level pre-tokenization (handles all characters)
tokenizer.pre_tokenizer = ByteLevel()
# Train on code corpus
trainer = BpeTrainer(
vocab_size=50000,
special_tokens=["<|endoftext|>", "<|pad|>"],
min_frequency=2
)
tokenizer.train(files=["code_corpus.txt"], trainer=trainer)
```
### Medical/scientific tokenizer
```python
# Preserve case and special characters
from tokenizers.normalizers import NFKC
from tokenizers.pre_tokenizers import Whitespace, Punctuation, Sequence
tokenizer = Tokenizer(BPE())
# Minimal normalization
tokenizer.normalizer = NFKC()
# Preserve medical terms
tokenizer.pre_tokenizer = Sequence([
Whitespace(),
Punctuation(behavior="isolated") # Keep punctuation separate
])
trainer = BpeTrainer(
vocab_size=50000,
special_tokens=["[UNK]", "[CLS]", "[SEP]"],
min_frequency=3 # Higher threshold for rare medical terms
)
tokenizer.train(files=["pubmed_corpus.txt"], trainer=trainer)
```
### Multilingual tokenizer
```python
# Handle multiple scripts
from tokenizers.normalizers import NFKC, Lowercase, Sequence
tokenizer = Tokenizer(BPE())
# Normalize but don't lowercase (preserves script differences)
tokenizer.normalizer = NFKC()
# Byte-level handles all Unicode
from tokenizers.pre_tokenizers import ByteLevel
tokenizer.pre_tokenizer = ByteLevel()
trainer = BpeTrainer(
vocab_size=100000, # Larger vocab for multiple languages
special_tokens=["<unk>", "<s>", "</s>"],
limit_alphabet=None # No limit (handles all scripts)
)
# Train on multilingual corpus
tokenizer.train(files=["multilingual_corpus.txt"], trainer=trainer)
```
## Vocabulary size selection
### Guidelines by task
| Task | Recommended Vocab Size | Rationale |
|-----------------------|------------------------|-----------|
| English (monolingual) | 30,000 - 50,000 | Balanced coverage |
| Multilingual | 50,000 - 250,000 | More languages = more tokens |
| Code | 30,000 - 50,000 | Similar to English |
| Domain-specific | 10,000 - 30,000 | Smaller, focused vocabulary |
| Character-level tasks | 1,000 - 5,000 | Only characters + subwords |
### Vocabulary size impact
**Small vocab (10k)**:
- Pros: Faster training, smaller model, less memory
- Cons: More tokens per sentence, worse OOV handling
**Medium vocab (30k-50k)**:
- Pros: Good balance, standard choice
- Cons: None (recommended default)
**Large vocab (100k+)**:
- Pros: Fewer tokens per sentence, better OOV
- Cons: Slower training, larger embedding table
### Empirical testing
```python
# Train multiple tokenizers with different vocab sizes
vocab_sizes = [10000, 30000, 50000, 100000]
for vocab_size in vocab_sizes:
tokenizer = Tokenizer(BPE())
trainer = BpeTrainer(vocab_size=vocab_size)
tokenizer.train(files=["sample.txt"], trainer=trainer)
# Evaluate on test set
test_text = "Test sentence for evaluation..."
tokens = tokenizer.encode(test_text).ids
print(f"Vocab: {vocab_size:6d} | Tokens: {len(tokens):3d} | Avg: {len(test_text)/len(tokens):.2f} chars/token")
# Example output:
# Vocab: 10000 | Tokens: 12 | Avg: 2.33 chars/token
# Vocab: 30000 | Tokens: 8 | Avg: 3.50 chars/token
# Vocab: 50000 | Tokens: 7 | Avg: 4.00 chars/token
# Vocab: 100000 | Tokens: 6 | Avg: 4.67 chars/token
```
## Testing tokenizer quality
### Coverage test
```python
# Test on held-out data
test_corpus = load_dataset("wikitext", "wikitext-103-raw-v1", split="test")
total_tokens = 0
unk_tokens = 0
unk_id = tokenizer.token_to_id("[UNK]")
for text in test_corpus["text"]:
if text.strip():
encoding = tokenizer.encode(text)
total_tokens += len(encoding.ids)
unk_tokens += encoding.ids.count(unk_id)
unk_rate = unk_tokens / total_tokens
print(f"Unknown token rate: {unk_rate:.2%}")
# Good quality: <1% unknown tokens
# Acceptable: 1-5%
# Poor: >5%
```
### Compression test
```python
# Measure tokenization efficiency
import numpy as np
token_lengths = []
for text in test_corpus["text"][:1000]:
if text.strip():
encoding = tokenizer.encode(text)
chars_per_token = len(text) / len(encoding.ids)
token_lengths.append(chars_per_token)
avg_chars_per_token = np.mean(token_lengths)
print(f"Average characters per token: {avg_chars_per_token:.2f}")
# Good: 4-6 chars/token (English)
# Acceptable: 3-4 chars/token
# Poor: <3 chars/token (under-compression)
```
### Semantic test
```python
# Manually inspect tokenization of common words/phrases
test_phrases = [
"tokenization",
"machine learning",
"artificial intelligence",
"preprocessing",
"hello world"
]
for phrase in test_phrases:
tokens = tokenizer.encode(phrase).tokens
print(f"{phrase:25s} → {tokens}")
# Good tokenization:
# tokenization → ['token', 'ization']
# machine learning → ['machine', 'learning']
# artificial intelligence → ['artificial', 'intelligence']
```
## Troubleshooting
### Issue: Training too slow
**Solutions**:
1. Reduce vocabulary size
2. Increase `min_frequency`
3. Use `limit_alphabet` to reduce initial alphabet
4. Train on subset first
```python
# Fast training configuration
trainer = BpeTrainer(
vocab_size=20000, # Smaller vocab
min_frequency=5, # Higher threshold
limit_alphabet=500, # Limit alphabet
show_progress=True
)
```
### Issue: High unknown token rate
**Solutions**:
1. Increase vocabulary size
2. Decrease `min_frequency`
3. Check normalization (might be too aggressive)
```python
# Better coverage configuration
trainer = BpeTrainer(
vocab_size=50000, # Larger vocab
min_frequency=1, # Lower threshold
)
```
### Issue: Poor quality tokenization
**Solutions**:
1. Verify normalization matches your use case
2. Check pre-tokenization splits correctly
3. Ensure training data is representative
4. Try different algorithm (BPE vs WordPiece vs Unigram)
```python
# Debug tokenization pipeline
text = "Sample text to debug"
# Check normalization
normalized = tokenizer.normalizer.normalize_str(text)
print(f"Normalized: {normalized}")
# Check pre-tokenization
pre_tokens = tokenizer.pre_tokenizer.pre_tokenize_str(text)
print(f"Pre-tokens: {pre_tokens}")
# Check final tokenization
tokens = tokenizer.encode(text).tokens
print(f"Tokens: {tokens}")
```
## Best practices
1. **Use representative training data** - Match your target domain
2. **Start with standard configs** - BERT WordPiece or GPT-2 BPE
3. **Test on held-out data** - Measure unknown token rate
4. **Iterate on vocabulary size** - Test 30k, 50k, 100k
5. **Save tokenizer with model** - Ensure reproducibility
6. **Version your tokenizers** - Track changes for reproducibility
7. **Document special tokens** - Critical for model training

View file

@ -0,0 +1,493 @@
---
name: evaluating-llms-harness
description: Evaluates LLMs across 60+ academic benchmarks (MMLU, HumanEval, GSM8K, TruthfulQA, HellaSwag). Use when benchmarking model quality, comparing models, reporting academic results, or tracking training progress. Industry standard used by EleutherAI, HuggingFace, and major labs. Supports HuggingFace, vLLM, APIs.
version: 1.0.0
author: Orchestra Research
license: MIT
dependencies: [lm-eval, transformers, vllm]
metadata:
hermes:
tags: [Evaluation, LM Evaluation Harness, Benchmarking, MMLU, HumanEval, GSM8K, EleutherAI, Model Quality, Academic Benchmarks, Industry Standard]
---
# lm-evaluation-harness - LLM Benchmarking
## Quick start
lm-evaluation-harness evaluates LLMs across 60+ academic benchmarks using standardized prompts and metrics.
**Installation**:
```bash
pip install lm-eval
```
**Evaluate any HuggingFace model**:
```bash
lm_eval --model hf \
--model_args pretrained=meta-llama/Llama-2-7b-hf \
--tasks mmlu,gsm8k,hellaswag \
--device cuda:0 \
--batch_size 8
```
**View available tasks**:
```bash
lm_eval --tasks list
```
## Common workflows
### Workflow 1: Standard benchmark evaluation
Evaluate model on core benchmarks (MMLU, GSM8K, HumanEval).
Copy this checklist:
```
Benchmark Evaluation:
- [ ] Step 1: Choose benchmark suite
- [ ] Step 2: Configure model
- [ ] Step 3: Run evaluation
- [ ] Step 4: Analyze results
```
**Step 1: Choose benchmark suite**
**Core reasoning benchmarks**:
- **MMLU** (Massive Multitask Language Understanding) - 57 subjects, multiple choice
- **GSM8K** - Grade school math word problems
- **HellaSwag** - Common sense reasoning
- **TruthfulQA** - Truthfulness and factuality
- **ARC** (AI2 Reasoning Challenge) - Science questions
**Code benchmarks**:
- **HumanEval** - Python code generation (164 problems)
- **MBPP** (Mostly Basic Python Problems) - Python coding
**Standard suite** (recommended for model releases):
```bash
--tasks mmlu,gsm8k,hellaswag,truthfulqa,arc_challenge
```
**Step 2: Configure model**
**HuggingFace model**:
```bash
lm_eval --model hf \
--model_args pretrained=meta-llama/Llama-2-7b-hf,dtype=bfloat16 \
--tasks mmlu \
--device cuda:0 \
--batch_size auto # Auto-detect optimal batch size
```
**Quantized model (4-bit/8-bit)**:
```bash
lm_eval --model hf \
--model_args pretrained=meta-llama/Llama-2-7b-hf,load_in_4bit=True \
--tasks mmlu \
--device cuda:0
```
**Custom checkpoint**:
```bash
lm_eval --model hf \
--model_args pretrained=/path/to/my-model,tokenizer=/path/to/tokenizer \
--tasks mmlu \
--device cuda:0
```
**Step 3: Run evaluation**
```bash
# Full MMLU evaluation (57 subjects), 5-shot (the standard setting)
lm_eval --model hf \
  --model_args pretrained=meta-llama/Llama-2-7b-hf \
  --tasks mmlu \
  --num_fewshot 5 \
--batch_size 8 \
--output_path results/ \
--log_samples # Save individual predictions
# Multiple benchmarks at once
lm_eval --model hf \
--model_args pretrained=meta-llama/Llama-2-7b-hf \
--tasks mmlu,gsm8k,hellaswag,truthfulqa,arc_challenge \
--num_fewshot 5 \
--batch_size 8 \
--output_path results/llama2-7b-eval.json
```
**Step 4: Analyze results**
Results saved to `results/llama2-7b-eval.json`:
```json
{
"results": {
"mmlu": {
"acc": 0.459,
"acc_stderr": 0.004
},
"gsm8k": {
"exact_match": 0.142,
"exact_match_stderr": 0.006
},
"hellaswag": {
"acc_norm": 0.765,
"acc_norm_stderr": 0.004
}
},
"config": {
"model": "hf",
"model_args": "pretrained=meta-llama/Llama-2-7b-hf",
"num_fewshot": 5
}
}
```
### Workflow 2: Track training progress
Evaluate checkpoints during training.
```
Training Progress Tracking:
- [ ] Step 1: Set up periodic evaluation
- [ ] Step 2: Choose quick benchmarks
- [ ] Step 3: Automate evaluation
- [ ] Step 4: Plot learning curves
```
**Step 1: Set up periodic evaluation**
Evaluate every N training steps:
```bash
#!/bin/bash
# eval_checkpoint.sh
CHECKPOINT_DIR=$1
STEP=$2
# 0-shot for speed
lm_eval --model hf \
  --model_args pretrained=$CHECKPOINT_DIR/checkpoint-$STEP \
  --tasks gsm8k,hellaswag \
  --num_fewshot 0 \
--batch_size 16 \
--output_path results/step-$STEP.json
```
**Step 2: Choose quick benchmarks**
Fast benchmarks for frequent evaluation:
- **HellaSwag**: ~10 minutes on 1 GPU
- **GSM8K**: ~5 minutes
- **PIQA**: ~2 minutes
Avoid for frequent eval (too slow):
- **MMLU**: ~2 hours (57 subjects)
- **HumanEval**: Requires code execution
**Step 3: Automate evaluation**
Integrate with training script:
```python
# In training loop
if step % eval_interval == 0:
model.save_pretrained(f"checkpoints/step-{step}")
# Run evaluation
os.system(f"./eval_checkpoint.sh checkpoints step-{step}")
```
Or use PyTorch Lightning callbacks:
```python
from pytorch_lightning import Callback
class EvalHarnessCallback(Callback):
def on_validation_epoch_end(self, trainer, pl_module):
step = trainer.global_step
checkpoint_path = f"checkpoints/step-{step}"
# Save checkpoint
trainer.save_checkpoint(checkpoint_path)
# Run lm-eval
os.system(f"lm_eval --model hf --model_args pretrained={checkpoint_path} ...")
```
**Step 4: Plot learning curves**
```python
import glob
import json
import matplotlib.pyplot as plt

# Load all results (the eval script above tracks gsm8k and hellaswag)
steps = []
hellaswag_scores = []
for file in sorted(glob.glob("results/step-*.json")):
    with open(file) as f:
        data = json.load(f)
    step = int(file.split("-")[1].split(".")[0])
    steps.append(step)
    hellaswag_scores.append(data["results"]["hellaswag"]["acc_norm"])

# Plot
plt.plot(steps, hellaswag_scores)
plt.xlabel("Training Step")
plt.ylabel("HellaSwag Accuracy (normalized)")
plt.title("Training Progress")
plt.savefig("training_curve.png")
```
### Workflow 3: Compare multiple models
Benchmark suite for model comparison.
```
Model Comparison:
- [ ] Step 1: Define model list
- [ ] Step 2: Run evaluations
- [ ] Step 3: Generate comparison table
```
**Step 1: Define model list**
```bash
# models.txt
meta-llama/Llama-2-7b-hf
meta-llama/Llama-2-13b-hf
mistralai/Mistral-7B-v0.1
microsoft/phi-2
```
**Step 2: Run evaluations**
```bash
#!/bin/bash
# eval_all_models.sh
TASKS="mmlu,gsm8k,hellaswag,truthfulqa"
while read model; do
echo "Evaluating $model"
# Extract model name for output file
model_name=$(echo $model | sed 's/\//-/g')
lm_eval --model hf \
--model_args pretrained=$model,dtype=bfloat16 \
--tasks $TASKS \
--num_fewshot 5 \
--batch_size auto \
--output_path results/$model_name.json
done < models.txt
```
**Step 3: Generate comparison table**
```python
import json
import pandas as pd
models = [
    "meta-llama/Llama-2-7b-hf",
    "meta-llama/Llama-2-13b-hf",
    "mistralai/Mistral-7B-v0.1",
    "microsoft/phi-2"
]
tasks = ["mmlu", "gsm8k", "hellaswag", "truthfulqa"]
results = []
for model in models:
    # Results were saved with "/" replaced by "-" in the filename (see the eval script above)
    file_name = model.replace("/", "-")
    with open(f"results/{file_name}.json") as f:
        data = json.load(f)
    row = {"Model": model}
for task in tasks:
# Get primary metric for each task
metrics = data["results"][task]
if "acc" in metrics:
row[task.upper()] = f"{metrics['acc']:.3f}"
elif "exact_match" in metrics:
row[task.upper()] = f"{metrics['exact_match']:.3f}"
results.append(row)
df = pd.DataFrame(results)
print(df.to_markdown(index=False))
```
Output:
```
| Model | MMLU | GSM8K | HELLASWAG | TRUTHFULQA |
|------------------------|-------|-------|-----------|------------|
| meta-llama/Llama-2-7b | 0.459 | 0.142 | 0.765 | 0.391 |
| meta-llama/Llama-2-13b | 0.549 | 0.287 | 0.801 | 0.430 |
| mistralai/Mistral-7B | 0.626 | 0.395 | 0.812 | 0.428 |
| microsoft/phi-2 | 0.560 | 0.613 | 0.682 | 0.447 |
```
### Workflow 4: Evaluate with vLLM (faster inference)
Use vLLM backend for 5-10x faster evaluation.
```
vLLM Evaluation:
- [ ] Step 1: Install vLLM
- [ ] Step 2: Configure vLLM backend
- [ ] Step 3: Run evaluation
```
**Step 1: Install vLLM**
```bash
pip install vllm
```
**Step 2: Configure vLLM backend**
```bash
lm_eval --model vllm \
--model_args pretrained=meta-llama/Llama-2-7b-hf,tensor_parallel_size=1,dtype=auto,gpu_memory_utilization=0.8 \
--tasks mmlu \
--batch_size auto
```
**Step 3: Run evaluation**
vLLM is 5-10× faster than standard HuggingFace:
```bash
# Standard HF: ~2 hours for MMLU on 7B model
lm_eval --model hf \
--model_args pretrained=meta-llama/Llama-2-7b-hf \
--tasks mmlu \
--batch_size 8
# vLLM: ~15-20 minutes for MMLU on 7B model
lm_eval --model vllm \
--model_args pretrained=meta-llama/Llama-2-7b-hf,tensor_parallel_size=2 \
--tasks mmlu \
--batch_size auto
```
## When to use vs alternatives
**Use lm-evaluation-harness when:**
- Benchmarking models for academic papers
- Comparing model quality across standard tasks
- Tracking training progress
- Reporting standardized metrics (everyone uses same prompts)
- Need reproducible evaluation
**Use alternatives instead:**
- **HELM** (Stanford): Broader evaluation (fairness, efficiency, calibration)
- **AlpacaEval**: Instruction-following evaluation with LLM judges
- **MT-Bench**: Conversational multi-turn evaluation
- **Custom scripts**: Domain-specific evaluation
## Common issues
**Issue: Evaluation too slow**
Use vLLM backend:
```bash
lm_eval --model vllm \
--model_args pretrained=model-name,tensor_parallel_size=2
```
Or reduce fewshot examples:
```bash
--num_fewshot 0 # Instead of 5
```
Or evaluate subset of MMLU:
```bash
--tasks mmlu_stem # Only STEM subjects
```
**Issue: Out of memory**
Reduce batch size:
```bash
--batch_size 1 # Or --batch_size auto
```
Use quantization:
```bash
--model_args pretrained=model-name,load_in_8bit=True
```
Enable CPU offloading:
```bash
--model_args pretrained=model-name,device_map=auto,offload_folder=offload
```
**Issue: Different results than reported**
Check fewshot count:
```bash
--num_fewshot 5 # Most papers use 5-shot
```
Check exact task name:
```bash
--tasks mmlu # Not mmlu_direct or mmlu_fewshot
```
Verify model and tokenizer match:
```bash
--model_args pretrained=model-name,tokenizer=same-model-name
```
**Issue: HumanEval not executing code**
Install execution dependencies:
```bash
pip install human-eval
```
Enable code execution:
```bash
lm_eval --model hf \
--model_args pretrained=model-name \
--tasks humaneval \
--allow_code_execution # Required for HumanEval
```
## Advanced topics
**Benchmark descriptions**: See [references/benchmark-guide.md](references/benchmark-guide.md) for detailed description of all 60+ tasks, what they measure, and interpretation.
**Custom tasks**: See [references/custom-tasks.md](references/custom-tasks.md) for creating domain-specific evaluation tasks.
**API evaluation**: See [references/api-evaluation.md](references/api-evaluation.md) for evaluating OpenAI, Anthropic, and other API models.
**Multi-GPU strategies**: See [references/distributed-eval.md](references/distributed-eval.md) for data parallel and tensor parallel evaluation.
## Hardware requirements
- **GPU**: NVIDIA (CUDA 11.8+), works on CPU (very slow)
- **VRAM**:
- 7B model: 16GB (bf16) or 8GB (8-bit)
- 13B model: 28GB (bf16) or 14GB (8-bit)
- 70B model: Requires multi-GPU or quantization
- **Time** (7B model, single A100):
- HellaSwag: 10 minutes
- GSM8K: 5 minutes
- MMLU (full): 2 hours
- HumanEval: 20 minutes
## Resources
- GitHub: https://github.com/EleutherAI/lm-evaluation-harness
- Docs: https://github.com/EleutherAI/lm-evaluation-harness/tree/main/docs
- Task library: 60+ tasks including MMLU, GSM8K, HumanEval, TruthfulQA, HellaSwag, ARC, WinoGrande, etc.
- Leaderboard: https://huggingface.co/spaces/HuggingFaceH4/open_llm_leaderboard (uses this harness)

View file

@ -0,0 +1,490 @@
# API Evaluation
Guide to evaluating OpenAI, Anthropic, and other API-based language models.
## Overview
The lm-evaluation-harness supports evaluating API-based models through a unified `TemplateAPI` interface. This allows benchmarking of:
- OpenAI models (GPT-4, GPT-3.5, etc.)
- Anthropic models (Claude 3, Claude 2, etc.)
- Local OpenAI-compatible APIs
- Custom API endpoints
**Why evaluate API models**:
- Benchmark closed-source models
- Compare API models to open models
- Validate API performance
- Track model updates over time
## Supported API Models
| Provider | Model Type | Request Types | Logprobs |
|----------|------------|---------------|----------|
| OpenAI (completions) | `openai-completions` | All | ✅ Yes |
| OpenAI (chat) | `openai-chat-completions` | `generate_until` only | ❌ No |
| Anthropic (completions) | `anthropic-completions` | All | ❌ No |
| Anthropic (chat) | `anthropic-chat` | `generate_until` only | ❌ No |
| Local (OpenAI-compatible) | `local-completions` | Depends on server | Varies |
**Note**: Models without logprobs can only be evaluated on generation tasks, not perplexity or loglikelihood tasks.
## OpenAI Models
### Setup
```bash
export OPENAI_API_KEY=sk-...
```
### Completion Models (Legacy)
**Available models**: `davinci-002`, `babbage-002`
```bash
lm_eval --model openai-completions \
--model_args model=davinci-002 \
--tasks lambada_openai,hellaswag \
--batch_size auto
```
**Supports**:
- `generate_until`: ✅
- `loglikelihood`: ✅
- `loglikelihood_rolling`: ✅
### Chat Models
**Available models**: `gpt-4`, `gpt-4-turbo`, `gpt-3.5-turbo`
```bash
lm_eval --model openai-chat-completions \
--model_args model=gpt-4-turbo \
--tasks mmlu,gsm8k,humaneval \
--num_fewshot 5 \
--batch_size auto
```
**Supports**:
- `generate_until`: ✅
- `loglikelihood`: ❌ (no logprobs)
- `loglikelihood_rolling`: ❌
**Important**: Chat models don't provide logprobs, so they can only be used with generation tasks (MMLU, GSM8K, HumanEval), not perplexity tasks.
### Configuration Options
```bash
lm_eval --model openai-chat-completions \
--model_args \
model=gpt-4-turbo,\
base_url=https://api.openai.com/v1,\
num_concurrent=5,\
max_retries=3,\
timeout=60,\
batch_size=auto
```
**Parameters**:
- `model`: Model identifier (required)
- `base_url`: API endpoint (default: OpenAI)
- `num_concurrent`: Concurrent requests (default: 5)
- `max_retries`: Retry failed requests (default: 3)
- `timeout`: Request timeout in seconds (default: 60)
- `tokenizer`: Tokenizer to use (default: matches model)
- `tokenizer_backend`: `"tiktoken"` or `"huggingface"`
### Cost Management
OpenAI charges per token. Estimate costs before running:
```python
# Rough estimate
num_samples = 1000
avg_tokens_per_sample = 500 # input + output
cost_per_1k_tokens = 0.01 # GPT-3.5 Turbo
total_cost = (num_samples * avg_tokens_per_sample / 1000) * cost_per_1k_tokens
print(f"Estimated cost: ${total_cost:.2f}")
```
**Cost-saving tips**:
- Use `--limit N` for testing
- Start with `gpt-3.5-turbo` before `gpt-4`
- Set `max_gen_toks` to minimum needed
- Use `num_fewshot=0` for zero-shot when possible
## Anthropic Models
### Setup
```bash
export ANTHROPIC_API_KEY=sk-ant-...
```
### Completion Models (Legacy)
```bash
lm_eval --model anthropic-completions \
--model_args model=claude-2.1 \
--tasks lambada_openai,hellaswag \
--batch_size auto
```
### Chat Models (Recommended)
**Available models**: `claude-3-5-sonnet-20241022`, `claude-3-opus-20240229`, `claude-3-sonnet-20240229`, `claude-3-haiku-20240307`
```bash
lm_eval --model anthropic-chat \
--model_args model=claude-3-5-sonnet-20241022 \
--tasks mmlu,gsm8k,humaneval \
--num_fewshot 5 \
--batch_size auto
```
**Aliases**: `anthropic-chat-completions` (same as `anthropic-chat`)
### Configuration Options
```bash
lm_eval --model anthropic-chat \
--model_args \
model=claude-3-5-sonnet-20241022,\
base_url=https://api.anthropic.com,\
num_concurrent=5,\
max_retries=3,\
timeout=60
```
### Cost Management
Anthropic pricing (as of 2024):
- Claude 3.5 Sonnet: $3.00 / 1M input, $15.00 / 1M output
- Claude 3 Opus: $15.00 / 1M input, $75.00 / 1M output
- Claude 3 Haiku: $0.25 / 1M input, $1.25 / 1M output
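As with OpenAI, estimate cost from these rates before launching a full run — a rough sketch; the token counts are assumptions to adjust per benchmark:

```python
# Rough estimate at Claude 3.5 Sonnet rates ($3 / 1M input, $15 / 1M output tokens)
num_samples = 14042             # e.g. full MMLU test set
input_tokens_per_sample = 600   # prompt + few-shot examples (assumption)
output_tokens_per_sample = 5    # short answers (assumption)

input_cost = num_samples * input_tokens_per_sample / 1_000_000 * 3.00
output_cost = num_samples * output_tokens_per_sample / 1_000_000 * 15.00
print(f"Estimated cost: ${input_cost + output_cost:.2f}")
```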
**Budget-friendly strategy**:
```bash
# Test on small sample first
lm_eval --model anthropic-chat \
--model_args model=claude-3-haiku-20240307 \
--tasks mmlu \
--limit 100
# Then run full eval on best model
lm_eval --model anthropic-chat \
--model_args model=claude-3-5-sonnet-20241022 \
--tasks mmlu \
--num_fewshot 5
```
## Local OpenAI-Compatible APIs
Many local inference servers expose OpenAI-compatible APIs (vLLM, Text Generation Inference, llama.cpp, Ollama).
### vLLM Local Server
**Start server**:
```bash
vllm serve meta-llama/Llama-2-7b-hf \
--host 0.0.0.0 \
--port 8000
```
**Evaluate**:
```bash
lm_eval --model local-completions \
--model_args \
model=meta-llama/Llama-2-7b-hf,\
base_url=http://localhost:8000/v1,\
num_concurrent=1 \
--tasks mmlu,gsm8k \
--batch_size auto
```
### Text Generation Inference (TGI)
**Start server**:
```bash
docker run --gpus all --shm-size 1g -p 8080:80 \
ghcr.io/huggingface/text-generation-inference:latest \
--model-id meta-llama/Llama-2-7b-hf
```
**Evaluate**:
```bash
lm_eval --model local-completions \
--model_args \
model=meta-llama/Llama-2-7b-hf,\
base_url=http://localhost:8080/v1 \
--tasks hellaswag,arc_challenge
```
### Ollama
**Start server**:
```bash
ollama serve
ollama pull llama2:7b
```
**Evaluate**:
```bash
lm_eval --model local-completions \
--model_args \
model=llama2:7b,\
base_url=http://localhost:11434/v1 \
--tasks mmlu
```
### llama.cpp Server
**Start server**:
```bash
./server -m models/llama-2-7b.gguf --host 0.0.0.0 --port 8080
```
**Evaluate**:
```bash
lm_eval --model local-completions \
--model_args \
model=llama2,\
base_url=http://localhost:8080/v1 \
--tasks gsm8k
```
## Custom API Implementation
For custom API endpoints, subclass `TemplateAPI`:
### Create `my_api.py`
```python
from lm_eval.models.api_models import TemplateAPI
import requests
class MyCustomAPI(TemplateAPI):
"""Custom API model."""
def __init__(self, base_url, api_key, **kwargs):
super().__init__(base_url=base_url, **kwargs)
self.api_key = api_key
def _create_payload(self, messages, gen_kwargs):
"""Create API request payload."""
return {
"messages": messages,
"api_key": self.api_key,
**gen_kwargs
}
def parse_generations(self, response):
"""Parse generation response."""
return response.json()["choices"][0]["text"]
def parse_logprobs(self, response):
"""Parse logprobs (if available)."""
# Return None if API doesn't provide logprobs
logprobs = response.json().get("logprobs")
if logprobs:
return logprobs["token_logprobs"]
return None
```
### Register and Use
```python
from lm_eval import evaluator
from my_api import MyCustomAPI
model = MyCustomAPI(
base_url="https://api.example.com/v1",
api_key="your-key"
)
results = evaluator.simple_evaluate(
model=model,
tasks=["mmlu", "gsm8k"],
num_fewshot=5,
batch_size="auto"
)
```
## Comparing API and Open Models
### Side-by-Side Evaluation
```bash
# Evaluate OpenAI GPT-4
lm_eval --model openai-chat-completions \
--model_args model=gpt-4-turbo \
--tasks mmlu,gsm8k,hellaswag \
--num_fewshot 5 \
--output_path results/gpt4.json
# Evaluate open Llama 2 70B
lm_eval --model hf \
--model_args pretrained=meta-llama/Llama-2-70b-hf,dtype=bfloat16 \
--tasks mmlu,gsm8k,hellaswag \
--num_fewshot 5 \
--output_path results/llama2-70b.json
# Compare results
python scripts/compare_results.py \
results/gpt4.json \
results/llama2-70b.json
```
### Typical Comparisons
| Model | MMLU | GSM8K | HumanEval | Cost |
|-------|------|-------|-----------|------|
| GPT-4 Turbo | 86.4% | 92.0% | 67.0% | $$$$ |
| Claude 3 Opus | 86.8% | 95.0% | 84.9% | $$$$ |
| GPT-3.5 Turbo | 70.0% | 57.1% | 48.1% | $$ |
| Llama 2 70B | 68.9% | 56.8% | 29.9% | Free (self-host) |
| Mixtral 8x7B | 70.6% | 58.4% | 40.2% | Free (self-host) |
## Best Practices
### Rate Limiting
Respect API rate limits:
```bash
# Lower concurrency and a longer timeout keep you under provider rate limits
lm_eval --model openai-chat-completions \
  --model_args model=gpt-4-turbo,num_concurrent=3,timeout=120 \
  --tasks mmlu
```
### Reproducibility
Set temperature to 0 for deterministic results:
```bash
lm_eval --model openai-chat-completions \
--model_args model=gpt-4-turbo \
--tasks mmlu \
--gen_kwargs temperature=0.0
```
Or use `seed` for sampling:
```bash
lm_eval --model anthropic-chat \
--model_args model=claude-3-5-sonnet-20241022 \
--tasks gsm8k \
--gen_kwargs temperature=0.7,seed=42
```
### Caching
API models automatically cache responses to avoid redundant calls:
```bash
# First run: makes API calls
lm_eval --model openai-chat-completions \
--model_args model=gpt-4-turbo \
--tasks mmlu \
--limit 100
# Second run: uses cache (instant, free)
lm_eval --model openai-chat-completions \
--model_args model=gpt-4-turbo \
--tasks mmlu \
--limit 100
```
Cache location: `~/.cache/lm_eval/`
### Error Handling
APIs can fail. Use retries:
```bash
lm_eval --model openai-chat-completions \
--model_args \
model=gpt-4-turbo,\
max_retries=5,\
timeout=120 \
--tasks mmlu
```
## Troubleshooting
### "Authentication failed"
Check API key:
```bash
echo $OPENAI_API_KEY # Should print sk-...
echo $ANTHROPIC_API_KEY # Should print sk-ant-...
```
### "Rate limit exceeded"
Reduce concurrency:
```bash
--model_args num_concurrent=1
```
Or add delays between requests.
### "Timeout error"
Increase timeout:
```bash
--model_args timeout=180
```
### "Model not found"
For local APIs, verify server is running:
```bash
curl http://localhost:8000/v1/models
```
### Cost Runaway
Use `--limit` for testing:
```bash
lm_eval --model openai-chat-completions \
--model_args model=gpt-4-turbo \
--tasks mmlu \
--limit 50 # Only 50 samples
```
## Advanced Features
### Custom Headers
```bash
lm_eval --model local-completions \
--model_args \
base_url=http://api.example.com/v1,\
header="Authorization: Bearer token,X-Custom: value"
```
### Disable SSL Verification (Development Only)
```bash
lm_eval --model local-completions \
--model_args \
base_url=https://localhost:8000/v1,\
verify_certificate=false
```
### Custom Tokenizer
```bash
lm_eval --model openai-chat-completions \
--model_args \
model=gpt-4-turbo,\
tokenizer=gpt2,\
tokenizer_backend=huggingface
```
## References
- OpenAI API: https://platform.openai.com/docs/api-reference
- Anthropic API: https://docs.anthropic.com/claude/reference
- TemplateAPI: `lm_eval/models/api_models.py`
- OpenAI models: `lm_eval/models/openai_completions.py`
- Anthropic models: `lm_eval/models/anthropic_llms.py`

View file

@ -0,0 +1,488 @@
# Benchmark Guide
Complete guide to all 60+ evaluation tasks in lm-evaluation-harness, what they measure, and how to interpret results.
## Overview
The lm-evaluation-harness includes 60+ benchmarks spanning:
- Language understanding (MMLU, GLUE)
- Mathematical reasoning (GSM8K, MATH)
- Code generation (HumanEval, MBPP)
- Instruction following (IFEval, AlpacaEval)
- Long-context understanding (LongBench)
- Multilingual capabilities (AfroBench, NorEval)
- Reasoning (BBH, ARC)
- Truthfulness (TruthfulQA)
**List all tasks**:
```bash
lm_eval --tasks list
```
## Major Benchmarks
### MMLU (Massive Multitask Language Understanding)
**What it measures**: Broad knowledge across 57 subjects (STEM, humanities, social sciences, law).
**Task variants**:
- `mmlu`: Original 57-subject benchmark
- `mmlu_pro`: More challenging version with reasoning-focused questions
- `mmlu_prox`: Multilingual extension
**Format**: Multiple choice (4 options)
**Example**:
```
Question: What is the capital of France?
A. Berlin
B. Paris
C. London
D. Madrid
Answer: B
```
**Command**:
```bash
lm_eval --model hf \
--model_args pretrained=meta-llama/Llama-2-7b-hf \
--tasks mmlu \
--num_fewshot 5
```
**Interpretation**:
- Random: 25% (chance)
- GPT-3 (175B): 43.9%
- GPT-4: 86.4%
- Human expert: ~90%
**Good for**: Assessing general knowledge and domain expertise.
### GSM8K (Grade School Math 8K)
**What it measures**: Mathematical reasoning on grade-school level word problems.
**Task variants**:
- `gsm8k`: Base task
- `gsm8k_cot`: With chain-of-thought prompting
- `gsm_plus`: Adversarial variant with perturbations
**Format**: Free-form generation, extract numerical answer
**Example**:
```
Question: A baker made 200 cookies. He sold 3/5 of them in the morning and 1/4 of the remaining in the afternoon. How many cookies does he have left?
Answer: 60
```
**Command**:
```bash
lm_eval --model hf \
--model_args pretrained=meta-llama/Llama-2-7b-hf \
--tasks gsm8k \
--num_fewshot 5
```
**Interpretation**:
- Random: ~0%
- GPT-3 (175B): 17.0%
- GPT-4: 92.0%
- Llama 2 70B: 56.8%
**Good for**: Testing multi-step reasoning and arithmetic.
### HumanEval
**What it measures**: Python code generation from docstrings (functional correctness).
**Task variants**:
- `humaneval`: Standard benchmark
- `humaneval_instruct`: For instruction-tuned models
**Format**: Code generation, execution-based evaluation
**Example**:
```python
def has_close_elements(numbers: List[float], threshold: float) -> bool:
""" Check if in given list of numbers, are any two numbers closer to each other than
given threshold.
>>> has_close_elements([1.0, 2.0, 3.0], 0.5)
False
>>> has_close_elements([1.0, 2.8, 3.0, 4.0, 5.0, 2.0], 0.3)
True
"""
```
**Command**:
```bash
lm_eval --model hf \
--model_args pretrained=codellama/CodeLlama-7b-hf \
--tasks humaneval \
--batch_size 1
```
**Interpretation**:
- Random: 0%
- GPT-3 (175B): 0%
- Codex: 28.8%
- GPT-4: 67.0%
- Code Llama 34B: 53.7%
**Good for**: Evaluating code generation capabilities.
### BBH (BIG-Bench Hard)
**What it measures**: 23 challenging reasoning tasks where models previously failed to beat humans.
**Categories**:
- Logical reasoning
- Math word problems
- Social understanding
- Algorithmic reasoning
**Format**: Multiple choice and free-form
**Command**:
```bash
lm_eval --model hf \
--model_args pretrained=meta-llama/Llama-2-7b-hf \
--tasks bbh \
--num_fewshot 3
```
**Interpretation**:
- Random: ~25%
- GPT-3 (175B): 33.9%
- PaLM 540B: 58.3%
- GPT-4: 86.7%
**Good for**: Testing advanced reasoning capabilities.
### IFEval (Instruction-Following Evaluation)
**What it measures**: Ability to follow specific, verifiable instructions.
**Instruction types**:
- Format constraints (e.g., "answer in 3 sentences")
- Length constraints (e.g., "use at least 100 words")
- Content constraints (e.g., "include the word 'banana'")
- Structural constraints (e.g., "use bullet points")
**Format**: Free-form generation with rule-based verification
**Command**:
```bash
lm_eval --model hf \
--model_args pretrained=meta-llama/Llama-2-7b-chat-hf \
--tasks ifeval \
--batch_size auto
```
**Interpretation**:
- Measures: Instruction adherence (not quality)
- GPT-4: 86% instruction following
- Claude 2: 84%
**Good for**: Evaluating chat/instruct models.
### GLUE (General Language Understanding Evaluation)
**What it measures**: Natural language understanding across 9 tasks.
**Tasks**:
- `cola`: Grammatical acceptability
- `sst2`: Sentiment analysis
- `mrpc`: Paraphrase detection
- `qqp`: Question pairs
- `stsb`: Semantic similarity
- `mnli`: Natural language inference
- `qnli`: Question answering NLI
- `rte`: Recognizing textual entailment
- `wnli`: Winograd schemas
**Command**:
```bash
lm_eval --model hf \
--model_args pretrained=bert-base-uncased \
--tasks glue \
--num_fewshot 0
```
**Interpretation**:
- BERT Base: 78.3 (GLUE score)
- RoBERTa Large: 88.5
- Human baseline: 87.1
**Good for**: Encoder-only models, fine-tuning baselines.
### LongBench
**What it measures**: Long-context understanding (4K-32K tokens).
**21 tasks covering**:
- Single-document QA
- Multi-document QA
- Summarization
- Few-shot learning
- Code completion
- Synthetic tasks
**Command**:
```bash
lm_eval --model hf \
--model_args pretrained=meta-llama/Llama-2-7b-hf \
--tasks longbench \
--batch_size 1
```
**Interpretation**:
- Tests context utilization
- Many models struggle beyond 4K tokens
- GPT-4 Turbo: 54.3%
**Good for**: Evaluating long-context models.
## Additional Benchmarks
### TruthfulQA
**What it measures**: Model's propensity to be truthful vs. generate plausible-sounding falsehoods.
**Format**: Multiple choice with 4-5 options
**Command**:
```bash
lm_eval --model hf \
--model_args pretrained=meta-llama/Llama-2-7b-hf \
--tasks truthfulqa_mc2 \
--batch_size auto
```
**Interpretation**:
- Larger models often score worse (more convincing lies)
- GPT-3: 58.8%
- GPT-4: 59.0%
- Human: ~94%
### ARC (AI2 Reasoning Challenge)
**What it measures**: Grade-school science questions.
**Variants**:
- `arc_easy`: Easier questions
- `arc_challenge`: Harder questions requiring reasoning
**Command**:
```bash
lm_eval --model hf \
--model_args pretrained=meta-llama/Llama-2-7b-hf \
--tasks arc_challenge \
--num_fewshot 25
```
**Interpretation**:
- ARC-Easy: Most models >80%
- ARC-Challenge random: 25%
- GPT-4: 96.3%
### HellaSwag
**What it measures**: Commonsense reasoning about everyday situations.
**Format**: Choose most plausible continuation
**Command**:
```bash
lm_eval --model hf \
--model_args pretrained=meta-llama/Llama-2-7b-hf \
--tasks hellaswag \
--num_fewshot 10
```
**Interpretation**:
- Random: 25%
- GPT-3: 78.9%
- Llama 2 70B: 85.3%
### WinoGrande
**What it measures**: Commonsense reasoning via pronoun resolution.
**Example**:
```
The trophy doesn't fit in the brown suitcase because _ is too large.
A. the trophy
B. the suitcase
```
**Command**:
```bash
lm_eval --model hf \
--model_args pretrained=meta-llama/Llama-2-7b-hf \
--tasks winogrande \
--num_fewshot 5
```
### PIQA
**What it measures**: Physical commonsense reasoning.
**Example**: "To clean a keyboard, use compressed air or..."
**Command**:
```bash
lm_eval --model hf \
--model_args pretrained=meta-llama/Llama-2-7b-hf \
--tasks piqa
```
## Multilingual Benchmarks
### AfroBench
**What it measures**: Performance across 64 African languages.
**15 tasks**: NLU, text generation, knowledge, QA, math reasoning
**Command**:
```bash
lm_eval --model hf \
--model_args pretrained=meta-llama/Llama-2-7b-hf \
--tasks afrobench
```
### NorEval
**What it measures**: Norwegian language understanding (9 task categories).
**Command**:
```bash
lm_eval --model hf \
--model_args pretrained=NbAiLab/nb-gpt-j-6B \
--tasks noreval
```
## Domain-Specific Benchmarks
### MATH
**What it measures**: High-school competition math problems.
**Command**:
```bash
lm_eval --model hf \
--model_args pretrained=meta-llama/Llama-2-7b-hf \
--tasks math \
--num_fewshot 4
```
**Interpretation**:
- Very challenging
- GPT-4: 42.5%
- Minerva 540B: 33.6%
### MBPP (Mostly Basic Python Problems)
**What it measures**: Python programming from natural language descriptions.
**Command**:
```bash
lm_eval --model hf \
--model_args pretrained=codellama/CodeLlama-7b-hf \
--tasks mbpp \
--batch_size 1
```
### DROP
**What it measures**: Reading comprehension requiring discrete reasoning.
**Command**:
```bash
lm_eval --model hf \
--model_args pretrained=meta-llama/Llama-2-7b-hf \
--tasks drop
```
## Benchmark Selection Guide
### For General Purpose Models
Run this suite:
```bash
lm_eval --model hf \
--model_args pretrained=meta-llama/Llama-2-7b-hf \
--tasks mmlu,gsm8k,hellaswag,arc_challenge,truthfulqa_mc2 \
--num_fewshot 5
```
### For Code Models
```bash
lm_eval --model hf \
--model_args pretrained=codellama/CodeLlama-7b-hf \
--tasks humaneval,mbpp \
--batch_size 1
```
### For Chat/Instruct Models
```bash
lm_eval --model hf \
--model_args pretrained=meta-llama/Llama-2-7b-chat-hf \
--tasks ifeval,mmlu,gsm8k_cot \
--batch_size auto
```
### For Long Context Models
```bash
lm_eval --model hf \
--model_args pretrained=meta-llama/Llama-3.1-8B \
--tasks longbench \
--batch_size 1
```
## Interpreting Results
### Understanding Metrics
**Accuracy**: Percentage of correct answers (most common)
**Exact Match (EM)**: Requires exact string match (strict)
**F1 Score**: Balances precision and recall
**BLEU/ROUGE**: Text generation similarity
**Pass@k**: Percentage passing when generating k samples
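Pass@k is normally reported with the unbiased estimator from the HumanEval paper (generate n samples per problem, count the c that pass) rather than literally running k samples — a short sketch:

```python
import numpy as np

def pass_at_k(n: int, c: int, k: int) -> float:
    """Unbiased pass@k estimate: n samples generated per problem, c of them passed."""
    if n - c < k:
        return 1.0
    return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1))

# Example: 200 samples per problem, 45 pass the unit tests
print(pass_at_k(n=200, c=45, k=1))   # ~0.225
print(pass_at_k(n=200, c=45, k=10))  # ~0.93
```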
### Typical Score Ranges
| Model Size | MMLU | GSM8K | HumanEval | HellaSwag |
|------------|------|-------|-----------|-----------|
| 7B | 40-50% | 10-20% | 5-15% | 70-80% |
| 13B | 45-55% | 20-35% | 15-25% | 75-82% |
| 70B | 60-70% | 50-65% | 35-50% | 82-87% |
| GPT-4 | 86% | 92% | 67% | 95% |
### Red Flags
- **All tasks at random chance**: Model not trained properly
- **Exact 0% on generation tasks**: Likely format/parsing issue
- **Huge variance across runs**: Check seed/sampling settings
- **Better than GPT-4 on everything**: Likely contamination
## Best Practices
1. **Always report few-shot setting**: 0-shot, 5-shot, etc.
2. **Run multiple seeds**: Report mean ± std (see the aggregation sketch below)
3. **Check for data contamination**: Search training data for benchmark examples
4. **Compare to published baselines**: Validate your setup
5. **Report all hyperparameters**: Model, batch size, max tokens, temperature
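A small sketch for aggregating several seeded runs into mean ± std, assuming each run was saved with `--output_path results/seed-<N>.json` (the file naming is an assumption):

```python
import glob
import json
import numpy as np

scores = []
for path in sorted(glob.glob("results/seed-*.json")):
    with open(path) as f:
        data = json.load(f)
    scores.append(data["results"]["mmlu"]["acc"])

print(f"MMLU acc: {np.mean(scores):.3f} ± {np.std(scores):.3f} over {len(scores)} runs")
```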
## References
- Task list: `lm_eval --tasks list`
- Task README: `lm_eval/tasks/README.md`
- Papers: See individual benchmark papers

View file

@ -0,0 +1,602 @@
# Custom Tasks
Complete guide to creating domain-specific evaluation tasks in lm-evaluation-harness.
## Overview
Custom tasks allow you to evaluate models on your own datasets and metrics. Tasks are defined using YAML configuration files with optional Python utilities for complex logic.
**Why create custom tasks**:
- Evaluate on proprietary/domain-specific data
- Test specific capabilities not covered by existing benchmarks
- Create evaluation pipelines for internal models
- Reproduce research experiments
## Quick Start
### Minimal Custom Task
Create `my_tasks/simple_qa.yaml`:
```yaml
task: simple_qa
dataset_path: data/simple_qa.jsonl
output_type: generate_until
doc_to_text: "Question: {{question}}\nAnswer:"
doc_to_target: "{{answer}}"
metric_list:
- metric: exact_match
aggregation: mean
higher_is_better: true
```
**Run it**:
```bash
lm_eval --model hf \
--model_args pretrained=meta-llama/Llama-2-7b-hf \
--tasks simple_qa \
--include_path my_tasks/
```
## Task Configuration Reference
### Essential Fields
```yaml
# Task identification
task: my_custom_task # Unique task name (required)
task_alias: "My Task" # Display name
tag: # Tags for grouping
- custom
- domain_specific
# Dataset configuration
dataset_path: data/my_data.jsonl # HuggingFace dataset or local path
dataset_name: default # Subset name (if applicable)
training_split: train
validation_split: validation
test_split: test
# Evaluation configuration
output_type: generate_until # or loglikelihood, multiple_choice
num_fewshot: 5 # Number of few-shot examples
batch_size: auto # Batch size
# Prompt templates (Jinja2)
doc_to_text: "Question: {{question}}"
doc_to_target: "{{answer}}"
# Metrics
metric_list:
- metric: exact_match
aggregation: mean
higher_is_better: true
# Metadata
metadata:
version: 1.0
```
### Output Types
**`generate_until`**: Free-form generation
```yaml
output_type: generate_until
generation_kwargs:
max_gen_toks: 256
until:
- "\n"
- "."
temperature: 0.0
```
**`loglikelihood`**: Compute log probability of targets
```yaml
output_type: loglikelihood
# Used for perplexity, classification
```
**`multiple_choice`**: Choose from options
```yaml
output_type: multiple_choice
doc_to_choice: "{{choices}}" # List of choices
```
## Data Formats
### Local JSONL File
`data/my_data.jsonl`:
```json
{"question": "What is 2+2?", "answer": "4"}
{"question": "Capital of France?", "answer": "Paris"}
```
**Task config**:
```yaml
dataset_path: data/my_data.jsonl
dataset_kwargs:
data_files:
test: data/my_data.jsonl
```
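If you generate the dataset yourself, a minimal sketch for writing the JSONL file above with the standard library:

```python
import json

records = [
    {"question": "What is 2+2?", "answer": "4"},
    {"question": "Capital of France?", "answer": "Paris"},
]

with open("data/my_data.jsonl", "w", encoding="utf-8") as f:
    for record in records:
        f.write(json.dumps(record, ensure_ascii=False) + "\n")
```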
### HuggingFace Dataset
```yaml
dataset_path: squad
dataset_name: plain_text
test_split: validation
```
### CSV File
`data/my_data.csv`:
```csv
question,answer,category
What is 2+2?,4,math
Capital of France?,Paris,geography
```
**Task config**:
```yaml
dataset_path: data/my_data.csv
dataset_kwargs:
data_files:
test: data/my_data.csv
```
## Prompt Engineering
### Simple Template
```yaml
doc_to_text: "Question: {{question}}\nAnswer:"
doc_to_target: "{{answer}}"
```
### Conditional Logic
```yaml
doc_to_text: |
{% if context %}
Context: {{context}}
{% endif %}
Question: {{question}}
Answer:
```
### Multiple Choice
```yaml
doc_to_text: |
Question: {{question}}
A. {{choices[0]}}
B. {{choices[1]}}
C. {{choices[2]}}
D. {{choices[3]}}
Answer:
doc_to_target: "{{ 'ABCD'[answer_idx] }}"
doc_to_choice: ["A", "B", "C", "D"]
```
### Few-Shot Formatting
```yaml
fewshot_delimiter: "\n\n" # Between examples
target_delimiter: " " # Between question and answer
doc_to_text: "Q: {{question}}"
doc_to_target: "A: {{answer}}"
```
## Custom Python Functions
For complex logic, use Python functions in `utils.py`.
### Create `my_tasks/utils.py`
```python
def process_docs(dataset):
"""Preprocess documents."""
def _process(doc):
# Custom preprocessing
doc["question"] = doc["question"].strip().lower()
return doc
return dataset.map(_process)
def doc_to_text(doc):
"""Custom prompt formatting."""
context = doc.get("context", "")
question = doc["question"]
if context:
return f"Context: {context}\nQuestion: {question}\nAnswer:"
return f"Question: {question}\nAnswer:"
def doc_to_target(doc):
"""Custom target extraction."""
return doc["answer"].strip().lower()
def aggregate_scores(items):
"""Custom metric aggregation."""
correct = sum(1 for item in items if item == 1.0)
total = len(items)
return correct / total if total > 0 else 0.0
```
### Use in Task Config
```yaml
task: my_custom_task
dataset_path: data/my_data.jsonl
# Use Python functions
process_docs: !function utils.process_docs
doc_to_text: !function utils.doc_to_text
doc_to_target: !function utils.doc_to_target
metric_list:
- metric: exact_match
aggregation: !function utils.aggregate_scores
higher_is_better: true
```
## Real-World Examples
### Example 1: Domain QA Task
**Goal**: Evaluate medical question answering.
`medical_qa/medical_qa.yaml`:
```yaml
task: medical_qa
dataset_path: data/medical_qa.jsonl
output_type: generate_until
num_fewshot: 3
doc_to_text: |
Medical Question: {{question}}
Context: {{context}}
Answer (be concise):
doc_to_target: "{{answer}}"
generation_kwargs:
max_gen_toks: 100
until:
- "\n\n"
temperature: 0.0
metric_list:
- metric: exact_match
aggregation: mean
higher_is_better: true
- metric: !function utils.medical_f1
aggregation: mean
higher_is_better: true
filter_list:
- name: lowercase
filter:
- function: lowercase
- function: remove_whitespace
metadata:
version: 1.0
domain: medical
```
`medical_qa/utils.py`:
```python
import re
def medical_f1(predictions, references):
"""Custom F1 for medical terms."""
pred_terms = set(extract_medical_terms(predictions[0]))
ref_terms = set(extract_medical_terms(references[0]))
if not pred_terms and not ref_terms:
return 1.0
if not pred_terms or not ref_terms:
return 0.0
tp = len(pred_terms & ref_terms)
fp = len(pred_terms - ref_terms)
fn = len(ref_terms - pred_terms)
precision = tp / (tp + fp) if (tp + fp) > 0 else 0
recall = tp / (tp + fn) if (tp + fn) > 0 else 0
return 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
def extract_medical_terms(text):
"""Extract medical terminology."""
# Custom logic
return re.findall(r'\b[A-Z][a-z]+(?:[A-Z][a-z]+)*\b', text)
```
### Example 2: Code Evaluation
`code_eval/python_challenges.yaml`:
```yaml
task: python_challenges
dataset_path: data/python_problems.jsonl
output_type: generate_until
num_fewshot: 0
doc_to_text: |
Write a Python function to solve:
{{problem_statement}}
Function signature:
{{function_signature}}
doc_to_target: "{{canonical_solution}}"
generation_kwargs:
max_gen_toks: 512
until:
- "\n\nclass"
- "\n\ndef"
temperature: 0.2
metric_list:
- metric: !function utils.execute_code
aggregation: mean
higher_is_better: true
process_results: !function utils.process_code_results
metadata:
version: 1.0
```
`code_eval/utils.py`:
```python
import subprocess
import json
def execute_code(predictions, references):
"""Execute generated code against test cases."""
generated_code = predictions[0]
test_cases = json.loads(references[0])
try:
# Execute code with test cases
for test_input, expected_output in test_cases:
result = execute_with_timeout(generated_code, test_input, timeout=5)
if result != expected_output:
return 0.0
return 1.0
except Exception:
return 0.0
def execute_with_timeout(code, input_data, timeout=5):
"""Safely execute code with timeout."""
# Implementation with subprocess and timeout
pass
def process_code_results(doc, results):
"""Process code execution results."""
return {
"passed": results[0] == 1.0,
"generated_code": results[1]
}
```
### Example 3: Instruction Following
`instruction_eval/instruction_eval.yaml`:
```yaml
task: instruction_following
dataset_path: data/instructions.jsonl
output_type: generate_until
num_fewshot: 0
doc_to_text: |
Instruction: {{instruction}}
{% if constraints %}
Constraints: {{constraints}}
{% endif %}
Response:
doc_to_target: "{{expected_response}}"
generation_kwargs:
max_gen_toks: 256
temperature: 0.7
metric_list:
- metric: !function utils.check_constraints
aggregation: mean
higher_is_better: true
- metric: !function utils.semantic_similarity
aggregation: mean
higher_is_better: true
process_docs: !function utils.add_constraint_checkers
```
`instruction_eval/utils.py`:
```python
import json

from sentence_transformers import SentenceTransformer, util

model = SentenceTransformer('all-MiniLM-L6-v2')
def check_constraints(predictions, references):
"""Check if response satisfies constraints."""
response = predictions[0]
constraints = json.loads(references[0])
satisfied = 0
total = len(constraints)
for constraint in constraints:
if verify_constraint(response, constraint):
satisfied += 1
return satisfied / total if total > 0 else 1.0
def verify_constraint(response, constraint):
"""Verify single constraint."""
if constraint["type"] == "length":
return len(response.split()) >= constraint["min_words"]
elif constraint["type"] == "contains":
return constraint["keyword"] in response.lower()
# Add more constraint types
return True
def semantic_similarity(predictions, references):
"""Compute semantic similarity."""
pred_embedding = model.encode(predictions[0])
ref_embedding = model.encode(references[0])
return float(util.cos_sim(pred_embedding, ref_embedding))
def add_constraint_checkers(dataset):
    """Parse constraints into a verifiable format."""
    def _parse(doc):
        # parse_constraints is a user-supplied helper (not shown) that turns the
        # raw constraint string into structured {"type": ..., ...} dicts
        doc["parsed_constraints"] = parse_constraints(doc.get("constraints", ""))
        return doc
    return dataset.map(_parse)
```
## Advanced Features
### Output Filtering
```yaml
filter_list:
- name: extract_answer
filter:
- function: regex
regex_pattern: "Answer: (.*)"
group: 1
- function: lowercase
- function: strip_whitespace
```
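To see what this chain does to a raw completion, here is an equivalent plain-Python sketch (it mirrors the regex, lowercase, and strip steps above; it is not the harness code):

```python
import re

def apply_filters(raw_output: str) -> str:
    # 1. regex: keep only what follows "Answer: "
    match = re.search(r"Answer: (.*)", raw_output)
    text = match.group(1) if match else raw_output
    # 2. lowercase
    text = text.lower()
    # 3. strip whitespace
    return text.strip()

print(apply_filters("Let me think step by step.\nAnswer:  Paris  "))  # "paris"
```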
### Multiple Metrics
```yaml
metric_list:
- metric: exact_match
aggregation: mean
higher_is_better: true
- metric: f1
aggregation: mean
higher_is_better: true
- metric: bleu
aggregation: mean
higher_is_better: true
```
### Task Groups
Create `my_tasks/_default.yaml`:
```yaml
group: my_eval_suite
task:
- simple_qa
- medical_qa
- python_challenges
```
**Run entire suite**:
```bash
lm_eval --model hf \
--model_args pretrained=meta-llama/Llama-2-7b-hf \
--tasks my_eval_suite \
--include_path my_tasks/
```
## Testing Your Task
### Validate Configuration
```bash
# Test task loading
lm_eval --tasks my_custom_task --include_path my_tasks/ --limit 0
# Run on 5 samples
lm_eval --model hf \
--model_args pretrained=gpt2 \
--tasks my_custom_task \
--include_path my_tasks/ \
--limit 5
```
### Debug Mode
```bash
lm_eval --model hf \
--model_args pretrained=gpt2 \
--tasks my_custom_task \
--include_path my_tasks/ \
--limit 1 \
--log_samples # Save input/output samples
```
## Best Practices
1. **Start simple**: Test with minimal config first
2. **Version your tasks**: Use `metadata.version`
3. **Document your metrics**: Explain custom metrics in comments
4. **Test with multiple models**: Ensure robustness
5. **Validate on known examples**: Include sanity checks
6. **Use filters carefully**: Can hide errors
7. **Handle edge cases**: Empty strings, missing fields
## Common Patterns
### Classification Task
```yaml
output_type: loglikelihood
doc_to_text: "Text: {{text}}\nLabel:"
doc_to_target: " {{label}}" # Space prefix important!
metric_list:
- metric: acc
aggregation: mean
```
### Perplexity Evaluation
```yaml
output_type: loglikelihood_rolling
doc_to_text: "{{text}}"
metric_list:
- metric: perplexity
aggregation: perplexity
```
### Ranking Task
```yaml
output_type: loglikelihood
doc_to_text: "Query: {{query}}\nPassage: {{passage}}\nRelevant:"
doc_to_target: [" Yes", " No"]
metric_list:
- metric: acc
aggregation: mean
```
## Troubleshooting
**"Task not found"**: Check `--include_path` and task name
**Empty results**: Verify `doc_to_text` and `doc_to_target` templates
**Metric errors**: Ensure metric names are correct (exact_match, not exact-match)
**Filter issues**: Test filters with `--log_samples`
**Python function not found**: Check `!function module.function_name` syntax
## References
- Task system: EleutherAI/lm-evaluation-harness docs
- Example tasks: `lm_eval/tasks/` directory
- TaskConfig: `lm_eval/api/task.py`


@ -0,0 +1,519 @@
# Distributed Evaluation
Guide to running evaluation across multiple GPUs using data parallelism and tensor/pipeline parallelism.
## Overview
Distributed evaluation speeds up benchmarking by:
- **Data Parallelism**: Split evaluation samples across GPUs (each GPU has full model copy)
- **Tensor Parallelism**: Split model weights across GPUs (for large models)
- **Pipeline Parallelism**: Split model layers across GPUs (for very large models)
**When to use**:
- Data Parallel: Model fits on single GPU, want faster evaluation
- Tensor/Pipeline Parallel: Model too large for single GPU
## HuggingFace Models (`hf`)
### Data Parallelism (Recommended)
Each GPU loads a full copy of the model and processes a subset of evaluation data.
**Single Node (8 GPUs)**:
```bash
accelerate launch --multi_gpu --num_processes 8 \
-m lm_eval --model hf \
--model_args pretrained=meta-llama/Llama-2-7b-hf,dtype=bfloat16 \
--tasks mmlu,gsm8k,hellaswag \
--batch_size 16
```
**Speedup**: Near-linear (8 GPUs = ~8× faster)
**Memory**: Each GPU needs full model (7B model ≈ 14GB × 8 = 112GB total)
### Tensor Parallelism (Model Sharding)
Split model weights across GPUs for models too large for single GPU.
**Without accelerate launcher**:
```bash
lm_eval --model hf \
--model_args \
pretrained=meta-llama/Llama-2-70b-hf,\
parallelize=True,\
dtype=bfloat16 \
--tasks mmlu,gsm8k \
--batch_size 8
```
**With 8 GPUs**: 70B model (140GB) / 8 = 17.5GB per GPU ✅
**Advanced sharding**:
```bash
lm_eval --model hf \
--model_args \
pretrained=meta-llama/Llama-2-70b-hf,\
parallelize=True,\
device_map_option=auto,\
max_memory_per_gpu=40GB,\
max_cpu_memory=100GB,\
dtype=bfloat16 \
--tasks mmlu
```
**Options**:
- `device_map_option`: `"auto"` (default), `"balanced"`, `"balanced_low_0"`
- `max_memory_per_gpu`: Max memory per GPU (e.g., `"40GB"`)
- `max_cpu_memory`: Max CPU memory for offloading
- `offload_folder`: Disk offloading directory
### Combined Data + Tensor Parallelism
Use both for very large models.
**Example: 70B model on 16 GPUs (2 copies, 8 GPUs each)**:
```bash
accelerate launch --multi_gpu --num_processes 2 \
-m lm_eval --model hf \
--model_args \
pretrained=meta-llama/Llama-2-70b-hf,\
parallelize=True,\
dtype=bfloat16 \
--tasks mmlu \
--batch_size 8
```
**Result**: 2× speedup from data parallelism, 70B model fits via tensor parallelism
### Configuration with `accelerate config`
Create `~/.cache/huggingface/accelerate/default_config.yaml`:
```yaml
compute_environment: LOCAL_MACHINE
distributed_type: MULTI_GPU
num_machines: 1
num_processes: 8
gpu_ids: all
mixed_precision: bf16
```
**Then run**:
```bash
accelerate launch -m lm_eval --model hf \
--model_args pretrained=meta-llama/Llama-2-7b-hf \
--tasks mmlu
```
## vLLM Models (`vllm`)
vLLM provides highly optimized distributed inference.
### Tensor Parallelism
**Single Node (4 GPUs)**:
```bash
lm_eval --model vllm \
--model_args \
pretrained=meta-llama/Llama-2-70b-hf,\
tensor_parallel_size=4,\
dtype=auto,\
gpu_memory_utilization=0.9 \
--tasks mmlu,gsm8k \
--batch_size auto
```
**Memory**: 70B model split across 4 GPUs = ~35GB per GPU
### Data Parallelism
**Multiple model replicas**:
```bash
lm_eval --model vllm \
--model_args \
pretrained=meta-llama/Llama-2-7b-hf,\
data_parallel_size=4,\
dtype=auto,\
gpu_memory_utilization=0.8 \
--tasks hellaswag,arc_challenge \
--batch_size auto
```
**Result**: 4 model replicas = 4× throughput
### Combined Tensor + Data Parallelism
**Example: 8 GPUs = 4 TP × 2 DP**:
```bash
lm_eval --model vllm \
--model_args \
pretrained=meta-llama/Llama-2-70b-hf,\
tensor_parallel_size=4,\
data_parallel_size=2,\
dtype=auto,\
gpu_memory_utilization=0.85 \
--tasks mmlu \
--batch_size auto
```
**Result**: 70B model fits (TP=4), 2× speedup (DP=2)
### Multi-Node vLLM
vLLM doesn't natively support multi-node. Use Ray:
```bash
# On the head node
ray start --head --port=6379

# On each worker node
ray start --address=<HEAD-NODE-IP>:6379
# Run evaluation
lm_eval --model vllm \
--model_args \
pretrained=meta-llama/Llama-2-70b-hf,\
tensor_parallel_size=8,\
dtype=auto \
--tasks mmlu
```
## NVIDIA NeMo Models (`nemo_lm`)
### Data Replication
**8 replicas on 8 GPUs**:
```bash
torchrun --nproc-per-node=8 --no-python \
lm_eval --model nemo_lm \
--model_args \
path=/path/to/model.nemo,\
devices=8 \
--tasks hellaswag,arc_challenge \
--batch_size 32
```
**Speedup**: Near-linear (8× faster)
### Tensor Parallelism
**4-way tensor parallelism**:
```bash
torchrun --nproc-per-node=4 --no-python \
lm_eval --model nemo_lm \
--model_args \
path=/path/to/70b_model.nemo,\
devices=4,\
tensor_model_parallel_size=4 \
--tasks mmlu,gsm8k \
--batch_size 16
```
### Pipeline Parallelism
**2 TP × 2 PP on 4 GPUs**:
```bash
torchrun --nproc-per-node=4 --no-python \
lm_eval --model nemo_lm \
--model_args \
path=/path/to/model.nemo,\
devices=4,\
tensor_model_parallel_size=2,\
pipeline_model_parallel_size=2 \
--tasks mmlu \
--batch_size 8
```
**Constraint**: `devices = TP × PP`
### Multi-Node NeMo
Currently not supported by lm-evaluation-harness.
## SGLang Models (`sglang`)
### Tensor Parallelism
```bash
lm_eval --model sglang \
--model_args \
pretrained=meta-llama/Llama-2-70b-hf,\
tp_size=4,\
dtype=auto \
--tasks gsm8k \
--batch_size auto
```
### Data Parallelism (Deprecated)
**Note**: SGLang is deprecating data parallelism. Use tensor parallelism instead.
```bash
lm_eval --model sglang \
--model_args \
pretrained=meta-llama/Llama-2-7b-hf,\
dp_size=4,\
dtype=auto \
--tasks mmlu
```
## Performance Comparison
### 70B Model Evaluation (MMLU, 5-shot)
| Method | GPUs | Time | Memory/GPU | Notes |
|--------|------|------|------------|-------|
| HF (no parallel) | 1 | 8 hours | 140GB (OOM) | Won't fit |
| HF (TP=8) | 8 | 2 hours | 17.5GB | Slower, fits |
| HF (DP=8) | 8 | 1 hour | 140GB (OOM) | Won't fit |
| vLLM (TP=4) | 4 | 30 min | 35GB | Fast! |
| vLLM (TP=4, DP=2) | 8 | 15 min | 35GB | Fastest |
### 7B Model Evaluation (Multiple Tasks)
| Method | GPUs | Time | Speedup |
|--------|------|------|---------|
| HF (single) | 1 | 4 hours | 1× |
| HF (DP=4) | 4 | 1 hour | 4× |
| HF (DP=8) | 8 | 30 min | 8× |
| vLLM (DP=8) | 8 | 15 min | 16× |
**Takeaway**: vLLM is significantly faster than HuggingFace for inference.
## Choosing Parallelism Strategy
### Decision Tree
```
Model fits on single GPU?
├─ YES: Use data parallelism
│ ├─ HF: accelerate launch --multi_gpu --num_processes N
│ └─ vLLM: data_parallel_size=N (fastest)
└─ NO: Use tensor/pipeline parallelism
├─ Model < 70B:
│ └─ vLLM: tensor_parallel_size=4
├─ Model 70-175B:
│ ├─ vLLM: tensor_parallel_size=8
│ └─ Or HF: parallelize=True
└─ Model > 175B:
└─ Contact framework authors
```
### Memory Estimation
**Rule of thumb**:
```
Memory (GB) = Parameters (B) × Precision (bytes) × 1.2 (overhead)
```
**Examples**:
- 7B FP16: 7 × 2 × 1.2 = 16.8GB ✅ Fits A100 40GB
- 13B FP16: 13 × 2 × 1.2 = 31.2GB ✅ Fits A100 40GB
- 70B FP16: 70 × 2 × 1.2 = 168GB ❌ Need TP=4 or TP=8
- 70B BF16: 70 × 2 × 1.2 = 168GB (same as FP16)
**With tensor parallelism**:
```
Memory per GPU = Total Memory / TP
```
- 70B on 4 GPUs: 168GB / 4 = 42GB per GPU ✅
- 70B on 8 GPUs: 168GB / 8 = 21GB per GPU ✅
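The rule of thumb as a small helper (a sketch; it uses the 1.2 overhead factor from above and ignores KV-cache and activation memory):

```python
def weights_gb(params_b: float, bytes_per_param: int = 2, overhead: float = 1.2) -> float:
    """Approximate weight memory in GB: params (billions) x precision x overhead."""
    return params_b * bytes_per_param * overhead

def min_tensor_parallel(params_b: float, gpu_gb: float = 80.0) -> int:
    """Smallest power-of-two TP degree whose per-GPU share fits on one GPU."""
    tp = 1
    while weights_gb(params_b) / tp > gpu_gb:
        tp *= 2
    return tp

print(weights_gb(70))           # 168.0 GB of weights in total
print(min_tensor_parallel(70))  # 4 on 80 GB GPUs (168 / 4 = 42 GB per GPU)
```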
## Multi-Node Evaluation
### HuggingFace with SLURM
**Submit job**:
```bash
#!/bin/bash
#SBATCH --nodes=4
#SBATCH --gpus-per-node=8
#SBATCH --ntasks-per-node=1
srun accelerate launch --multi_gpu \
--num_processes $((SLURM_NNODES * 8)) \
-m lm_eval --model hf \
--model_args pretrained=meta-llama/Llama-2-7b-hf \
--tasks mmlu,gsm8k,hellaswag \
--batch_size 16
```
**Submit**:
```bash
sbatch eval_job.sh
```
### Manual Multi-Node Setup
**On each node, run**:
```bash
accelerate launch \
--multi_gpu \
--num_machines 4 \
--num_processes 32 \
--main_process_ip $MASTER_IP \
--main_process_port 29500 \
--machine_rank $NODE_RANK \
-m lm_eval --model hf \
--model_args pretrained=meta-llama/Llama-2-7b-hf \
--tasks mmlu
```
**Environment variables**:
- `MASTER_IP`: IP of rank 0 node
- `NODE_RANK`: 0, 1, 2, 3 for each node
## Best Practices
### 1. Start Small
Test on small sample first:
```bash
lm_eval --model hf \
--model_args pretrained=meta-llama/Llama-2-70b-hf,parallelize=True \
--tasks mmlu \
--limit 100 # Just 100 samples
```
### 2. Monitor GPU Usage
```bash
# Terminal 1: Run evaluation
lm_eval --model hf ...
# Terminal 2: Monitor
watch -n 1 nvidia-smi
```
Look for:
- GPU utilization > 90%
- Memory usage stable
- All GPUs active
### 3. Optimize Batch Size
```bash
# Auto batch size (recommended)
--batch_size auto
# Or tune manually
--batch_size 16 # Start here
--batch_size 32 # Increase if memory allows
```
### 4. Use Mixed Precision
```bash
--model_args dtype=bfloat16 # Faster, less memory
```
### 5. Check Communication
For data parallelism, check network bandwidth:
```bash
# Should see InfiniBand or high-speed network
nvidia-smi topo -m
```
## Troubleshooting
### "CUDA out of memory"
**Solutions**:
1. Increase tensor parallelism:
```bash
--model_args tensor_parallel_size=8 # Was 4
```
2. Reduce batch size:
```bash
--batch_size 4 # Was 16
```
3. Lower precision:
```bash
--model_args load_in_8bit=True  # 8-bit quantization (requires bitsandbytes)
```
### "NCCL error" or Hanging
**Check**:
1. All GPUs visible: `nvidia-smi`
2. NCCL installed: `python -c "import torch; print(torch.cuda.nccl.version())"`
3. Network connectivity between nodes
**Fix**:
```bash
export NCCL_DEBUG=INFO # Enable debug logging
export NCCL_IB_DISABLE=0 # Use InfiniBand if available
```
### Slow Evaluation
**Possible causes**:
1. **Data loading bottleneck**: Preprocess dataset
2. **Low GPU utilization**: Increase batch size
3. **Communication overhead**: Reduce parallelism degree
**Profile**:
```bash
lm_eval --model hf \
--model_args pretrained=meta-llama/Llama-2-7b-hf \
--tasks mmlu \
--limit 100 \
--log_samples # Check timing
```
### GPUs Imbalanced
**Symptom**: GPU 0 at 100%, others at 50%
**Solution**: Use `device_map_option=balanced`:
```bash
--model_args parallelize=True,device_map_option=balanced
```
## Example Configurations
### Small Model (7B) - Fast Evaluation
```bash
# 8 A100s, data parallel
accelerate launch --multi_gpu --num_processes 8 \
-m lm_eval --model hf \
--model_args \
pretrained=meta-llama/Llama-2-7b-hf,\
dtype=bfloat16 \
--tasks mmlu,gsm8k,hellaswag,arc_challenge \
--num_fewshot 5 \
--batch_size 32
# Time: ~30 minutes
```
### Large Model (70B) - vLLM
```bash
# 8 H100s, tensor parallel
lm_eval --model vllm \
--model_args \
pretrained=meta-llama/Llama-2-70b-hf,\
tensor_parallel_size=8,\
dtype=auto,\
gpu_memory_utilization=0.9 \
--tasks mmlu,gsm8k,humaneval \
--num_fewshot 5 \
--batch_size auto
# Time: ~1 hour
```
### Very Large Model (175B+)
**Requires specialized setup - contact framework maintainers**
## References
- HuggingFace Accelerate: https://huggingface.co/docs/accelerate/
- vLLM docs: https://docs.vllm.ai/
- NeMo docs: https://docs.nvidia.com/nemo-framework/
- lm-eval distributed guide: `docs/model_guide.md`


@ -0,0 +1,386 @@
---
name: nemo-curator
description: GPU-accelerated data curation for LLM training. Supports text/image/video/audio. Features fuzzy deduplication (16× faster), quality filtering (30+ heuristics), semantic deduplication, PII redaction, NSFW detection. Scales across GPUs with RAPIDS. Use for preparing high-quality training datasets, cleaning web data, or deduplicating large corpora.
version: 1.0.0
author: Orchestra Research
license: MIT
dependencies: [nemo-curator, cudf, dask, rapids]
metadata:
hermes:
tags: [Data Processing, NeMo Curator, Data Curation, GPU Acceleration, Deduplication, Quality Filtering, NVIDIA, RAPIDS, PII Redaction, Multimodal, LLM Training Data]
---
# NeMo Curator - GPU-Accelerated Data Curation
NVIDIA's toolkit for preparing high-quality training data for LLMs.
## When to use NeMo Curator
**Use NeMo Curator when:**
- Preparing LLM training data from web scrapes (Common Crawl)
- Need fast deduplication (16× faster than CPU)
- Curating multi-modal datasets (text, images, video, audio)
- Filtering low-quality or toxic content
- Scaling data processing across GPU cluster
**Performance**:
- **16× faster** fuzzy deduplication (8TB RedPajama v2)
- **40% lower TCO** vs CPU alternatives
- **Near-linear scaling** across GPU nodes
**Use alternatives instead**:
- **datatrove**: CPU-based, open-source data processing
- **dolma**: Allen AI's data toolkit
- **Ray Data**: General ML data processing (no curation focus)
## Quick start
### Installation
```bash
# Text curation (CUDA 12)
uv pip install "nemo-curator[text_cuda12]"
# All modalities
uv pip install "nemo-curator[all_cuda12]"
# CPU-only (slower)
uv pip install "nemo-curator[cpu]"
```
### Basic text curation pipeline
```python
from nemo_curator import ScoreFilter, Modify
from nemo_curator.datasets import DocumentDataset
import pandas as pd
# Load data
df = pd.DataFrame({"text": ["Good document", "Bad doc", "Excellent text"]})
dataset = DocumentDataset(df)
# Quality filtering
def quality_score(doc):
return len(doc["text"].split()) > 5 # Filter short docs
filtered = ScoreFilter(quality_score)(dataset)
# Deduplication
from nemo_curator.modules import ExactDuplicates
deduped = ExactDuplicates()(filtered)
# Save
deduped.to_parquet("curated_data/")
```
## Data curation pipeline
### Stage 1: Quality filtering
```python
from nemo_curator.filters import (
WordCountFilter,
RepeatedLinesFilter,
UrlRatioFilter,
NonAlphaNumericFilter
)
# Apply 30+ heuristic filters
from nemo_curator import ScoreFilter
# Word count filter
dataset = dataset.filter(WordCountFilter(min_words=50, max_words=100000))
# Remove repetitive content
dataset = dataset.filter(RepeatedLinesFilter(max_repeated_line_fraction=0.3))
# URL ratio filter
dataset = dataset.filter(UrlRatioFilter(max_url_ratio=0.2))
```
### Stage 2: Deduplication
**Exact deduplication**:
```python
from nemo_curator.modules import ExactDuplicates
# Remove exact duplicates
deduped = ExactDuplicates(id_field="id", text_field="text")(dataset)
```
**Fuzzy deduplication** (16× faster on GPU):
```python
from nemo_curator.modules import FuzzyDuplicates
# MinHash + LSH deduplication
fuzzy_dedup = FuzzyDuplicates(
id_field="id",
text_field="text",
num_hashes=260, # MinHash parameters
num_buckets=20,
hash_method="md5"
)
deduped = fuzzy_dedup(dataset)
```
**Semantic deduplication**:
```python
from nemo_curator.modules import SemanticDuplicates
# Embedding-based deduplication
semantic_dedup = SemanticDuplicates(
id_field="id",
text_field="text",
embedding_model="sentence-transformers/all-MiniLM-L6-v2",
threshold=0.8 # Cosine similarity threshold
)
deduped = semantic_dedup(dataset)
```
### Stage 3: PII redaction
```python
from nemo_curator.modules import Modify
from nemo_curator.modifiers import PIIRedactor
# Redact personally identifiable information
pii_redactor = PIIRedactor(
supported_entities=["EMAIL_ADDRESS", "PHONE_NUMBER", "PERSON", "LOCATION"],
anonymize_action="replace" # or "redact"
)
redacted = Modify(pii_redactor)(dataset)
```
### Stage 4: Classifier filtering
```python
from nemo_curator.classifiers import QualityClassifier
# Quality classification
quality_clf = QualityClassifier(
model_path="nvidia/quality-classifier-deberta",
batch_size=256,
device="cuda"
)
# Filter low-quality documents
high_quality = dataset.filter(lambda doc: quality_clf(doc["text"]) > 0.5)
```
## GPU acceleration
### GPU vs CPU performance
| Operation | CPU (16 cores) | GPU (A100) | Speedup |
|-----------|----------------|------------|---------|
| Fuzzy dedup (8TB) | 120 hours | 7.5 hours | 16× |
| Exact dedup (1TB) | 8 hours | 0.5 hours | 16× |
| Quality filtering | 2 hours | 0.2 hours | 10× |
### Multi-GPU scaling
```python
from nemo_curator import get_client
import dask_cuda
# Initialize GPU cluster
client = get_client(cluster_type="gpu", n_workers=8)
# Process with 8 GPUs
deduped = FuzzyDuplicates(...)(dataset)
```
## Multi-modal curation
### Image curation
```python
from nemo_curator.image import (
AestheticFilter,
NSFWFilter,
CLIPEmbedder
)
# Aesthetic scoring
aesthetic_filter = AestheticFilter(threshold=5.0)
filtered_images = aesthetic_filter(image_dataset)
# NSFW detection
nsfw_filter = NSFWFilter(threshold=0.9)
safe_images = nsfw_filter(filtered_images)
# Generate CLIP embeddings
clip_embedder = CLIPEmbedder(model="openai/clip-vit-base-patch32")
image_embeddings = clip_embedder(safe_images)
```
### Video curation
```python
from nemo_curator.video import (
SceneDetector,
ClipExtractor,
InternVideo2Embedder
)
# Detect scenes
scene_detector = SceneDetector(threshold=27.0)
scenes = scene_detector(video_dataset)
# Extract clips
clip_extractor = ClipExtractor(min_duration=2.0, max_duration=10.0)
clips = clip_extractor(scenes)
# Generate embeddings
video_embedder = InternVideo2Embedder()
video_embeddings = video_embedder(clips)
```
### Audio curation
```python
from nemo_curator.audio import (
ASRInference,
WERFilter,
DurationFilter
)
# ASR transcription
asr = ASRInference(model="nvidia/stt_en_fastconformer_hybrid_large_pc")
transcribed = asr(audio_dataset)
# Filter by WER (word error rate)
wer_filter = WERFilter(max_wer=0.3)
high_quality_audio = wer_filter(transcribed)
# Duration filtering
duration_filter = DurationFilter(min_duration=1.0, max_duration=30.0)
filtered_audio = duration_filter(high_quality_audio)
```
## Common patterns
### Web scrape curation (Common Crawl)
```python
from nemo_curator import ScoreFilter, Modify
from nemo_curator.filters import *
from nemo_curator.modules import *
from nemo_curator.datasets import DocumentDataset
# Load Common Crawl data
dataset = DocumentDataset.read_parquet("common_crawl/*.parquet")
# Pipeline
pipeline = [
# 1. Quality filtering
WordCountFilter(min_words=100, max_words=50000),
RepeatedLinesFilter(max_repeated_line_fraction=0.2),
SymbolToWordRatioFilter(max_symbol_to_word_ratio=0.3),
UrlRatioFilter(max_url_ratio=0.3),
# 2. Language filtering
LanguageIdentificationFilter(target_languages=["en"]),
# 3. Deduplication
ExactDuplicates(id_field="id", text_field="text"),
FuzzyDuplicates(id_field="id", text_field="text", num_hashes=260),
# 4. PII redaction
PIIRedactor(),
# 5. NSFW filtering
NSFWClassifier(threshold=0.8)
]
# Execute
for stage in pipeline:
dataset = stage(dataset)
# Save
dataset.to_parquet("curated_common_crawl/")
```
### Distributed processing
```python
from nemo_curator import get_client
from dask_cuda import LocalCUDACluster
# Multi-GPU cluster
cluster = LocalCUDACluster(n_workers=8)
client = get_client(cluster=cluster)
# Process large dataset
dataset = DocumentDataset.read_parquet("s3://large_dataset/*.parquet")
deduped = FuzzyDuplicates(...)(dataset)
# Cleanup
client.close()
cluster.close()
```
## Performance benchmarks
### Fuzzy deduplication (8TB RedPajama v2)
- **CPU (256 cores)**: 120 hours
- **GPU (8× A100)**: 7.5 hours
- **Speedup**: 16×
### Exact deduplication (1TB)
- **CPU (64 cores)**: 8 hours
- **GPU (4× A100)**: 0.5 hours
- **Speedup**: 16×
### Quality filtering (100GB)
- **CPU (32 cores)**: 2 hours
- **GPU (2× A100)**: 0.2 hours
- **Speedup**: 10×
## Cost comparison
**CPU-based curation** (AWS c5.18xlarge × 10):
- Cost: $3.60/hour × 10 = $36/hour
- Time for 8TB: 120 hours
- **Total**: $4,320
**GPU-based curation** (AWS p4d.24xlarge × 2):
- Cost: $32.77/hour × 2 = $65.54/hour
- Time for 8TB: 7.5 hours
- **Total**: $491.55
**Savings**: 89% reduction ($3,828 saved)
## Supported data formats
- **Input**: Parquet, JSONL, CSV
- **Output**: Parquet (recommended), JSONL
- **WebDataset**: TAR archives for multi-modal
## Use cases
**Production deployments**:
- NVIDIA used NeMo Curator to prepare Nemotron-4 training data
- Open-source datasets curated: RedPajama v2, The Pile
## References
- **[Filtering Guide](references/filtering.md)** - 30+ quality filters, heuristics
- **[Deduplication Guide](references/deduplication.md)** - Exact, fuzzy, semantic methods
## Resources
- **GitHub**: https://github.com/NVIDIA/NeMo-Curator ⭐ 500+
- **Docs**: https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/
- **Version**: 0.4.0+
- **License**: Apache 2.0


@ -0,0 +1,87 @@
# Deduplication Guide
Complete guide to exact, fuzzy, and semantic deduplication.
## Exact deduplication
Remove documents with identical content.
```python
from nemo_curator.modules import ExactDuplicates
# Exact deduplication
exact_dedup = ExactDuplicates(
id_field="id",
text_field="text",
hash_method="md5" # or "sha256"
)
deduped = exact_dedup(dataset)
```
**Performance**: ~16× faster on GPU vs CPU
## Fuzzy deduplication
Remove near-duplicate documents using MinHash + LSH.
```python
from nemo_curator.modules import FuzzyDuplicates
fuzzy_dedup = FuzzyDuplicates(
id_field="id",
text_field="text",
num_hashes=260, # MinHash permutations (more = accurate)
num_buckets=20, # LSH buckets (more = faster, less recall)
hash_method="md5",
jaccard_threshold=0.8 # Similarity threshold
)
deduped = fuzzy_dedup(dataset)
```
**Parameters**:
- `num_hashes`: 128-512 (default 260)
- `num_buckets`: 10-50 (default 20)
- `jaccard_threshold`: 0.7-0.9 (default 0.8)
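The `jaccard_threshold` is the Jaccard similarity of the documents' shingle sets that MinHash + LSH approximates. A pure-Python sketch of the exact quantity (illustrative only, not the NeMo Curator implementation, which shingles differently):

```python
def shingles(text: str, n: int = 3) -> set:
    """Set of n-word shingles for a document."""
    words = text.lower().split()
    return {" ".join(words[i:i + n]) for i in range(len(words) - n + 1)}

def jaccard(a: str, b: str) -> float:
    sa, sb = shingles(a), shingles(b)
    return len(sa & sb) / len(sa | sb) if sa | sb else 0.0

doc1 = "the quick brown fox jumps over the lazy dog"
doc2 = "the quick brown fox leaps over the lazy dog"
print(jaccard(doc1, doc2))  # 0.4: below the default 0.8 threshold, so kept as distinct
```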
**Performance**: 16× faster on 8TB dataset (120h → 7.5h)
## Semantic deduplication
Remove semantically similar documents using embeddings.
```python
from nemo_curator.modules import SemanticDuplicates
semantic_dedup = SemanticDuplicates(
id_field="id",
text_field="text",
embedding_model="sentence-transformers/all-MiniLM-L6-v2",
embedding_batch_size=256,
threshold=0.85, # Cosine similarity threshold
device="cuda"
)
deduped = semantic_dedup(dataset)
```
**Models**:
- `all-MiniLM-L6-v2`: Fast, 384 dims
- `all-mpnet-base-v2`: Better quality, 768 dims
- Custom models supported
## Comparison
| Method | Speed | Recall | Use Case |
|--------|-------|--------|----------|
| Exact | Fastest | 100% | Exact matches only |
| Fuzzy | Fast | ~95% | Near-duplicates (recommended) |
| Semantic | Slow | ~90% | Paraphrases, rewrites |
## Best practices
1. **Start with exact dedup** - Remove obvious duplicates
2. **Use fuzzy for large datasets** - Best speed/quality trade-off
3. **Semantic for high-value data** - Expensive but thorough
4. **GPU acceleration required** - 10-16× speedup


@ -0,0 +1,102 @@
# Quality Filtering Guide
Complete guide to NeMo Curator's 30+ quality filters.
## Text-based filters
### Word count
```python
from nemo_curator.filters import WordCountFilter
# Filter by word count
dataset = dataset.filter(WordCountFilter(min_words=50, max_words=100000))
```
### Repeated content
```python
from nemo_curator.filters import RepeatedLinesFilter
# Remove documents with >30% repeated lines
dataset = dataset.filter(RepeatedLinesFilter(max_repeated_line_fraction=0.3))
```
### Symbol ratio
```python
from nemo_curator.filters import SymbolToWordRatioFilter
# Remove documents with too many symbols
dataset = dataset.filter(SymbolToWordRatioFilter(max_symbol_to_word_ratio=0.3))
```
### URL ratio
```python
from nemo_curator.filters import UrlRatioFilter
# Remove documents with many URLs
dataset = dataset.filter(UrlRatioFilter(max_url_ratio=0.2))
```
## Language filtering
```python
from nemo_curator.filters import LanguageIdentificationFilter
# Keep only English documents
dataset = dataset.filter(LanguageIdentificationFilter(target_languages=["en"]))
# Multiple languages
dataset = dataset.filter(LanguageIdentificationFilter(target_languages=["en", "es", "fr"]))
```
## Classifier-based filtering
### Quality classifier
```python
from nemo_curator.classifiers import QualityClassifier
quality_clf = QualityClassifier(
model_path="nvidia/quality-classifier-deberta",
batch_size=256,
device="cuda"
)
# Filter low-quality (threshold > 0.5 = high quality)
dataset = dataset.filter(lambda doc: quality_clf(doc["text"]) > 0.5)
```
### NSFW classifier
```python
from nemo_curator.classifiers import NSFWClassifier
nsfw_clf = NSFWClassifier(threshold=0.9, device="cuda")
# Remove NSFW content
dataset = dataset.filter(lambda doc: nsfw_clf(doc["text"]) < 0.9)
```
## Heuristic filters
Full list of 30+ filters:
- WordCountFilter
- RepeatedLinesFilter
- UrlRatioFilter
- SymbolToWordRatioFilter
- NonAlphaNumericFilter
- BulletsFilter
- WhiteSpaceFilter
- ParenthesesFilter
- LongWordFilter
- And 20+ more...
## Best practices
1. **Apply cheap filters first** - Word count before GPU classifiers
2. **Tune thresholds on sample** - Test on 10k docs before full run (see the sketch after this list)
3. **Use GPU classifiers sparingly** - Expensive but effective
4. **Chain filters efficiently** - Order by cost (cheap → expensive)
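For point 2, a quick sketch of tuning a word-count threshold on a sample before the full run (pandas only; the file path and `text` column are assumptions about your data):

```python
import pandas as pd

# Sample ~10k documents from one shard of the full dataset
sample = pd.read_parquet("common_crawl/part-00000.parquet").sample(10_000, random_state=0)
word_counts = sample["text"].str.split().str.len()

for min_words in (25, 50, 100, 200):
    kept = (word_counts >= min_words).mean()
    print(f"min_words={min_words}: keeps {kept:.1%} of documents")
```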


@ -0,0 +1,389 @@
---
name: sparse-autoencoder-training
description: Provides guidance for training and analyzing Sparse Autoencoders (SAEs) using SAELens to decompose neural network activations into interpretable features. Use when discovering interpretable features, analyzing superposition, or studying monosemantic representations in language models.
version: 1.0.0
author: Orchestra Research
license: MIT
dependencies: [sae-lens>=6.0.0, transformer-lens>=2.0.0, torch>=2.0.0]
metadata:
hermes:
tags: [Sparse Autoencoders, SAE, Mechanistic Interpretability, Feature Discovery, Superposition]
---
# SAELens: Sparse Autoencoders for Mechanistic Interpretability
SAELens is the primary library for training and analyzing Sparse Autoencoders (SAEs) - a technique for decomposing polysemantic neural network activations into sparse, interpretable features. Based on Anthropic's groundbreaking research on monosemanticity.
**GitHub**: [jbloomAus/SAELens](https://github.com/jbloomAus/SAELens) (1,100+ stars)
## The Problem: Polysemanticity & Superposition
Individual neurons in neural networks are **polysemantic** - they activate in multiple, semantically distinct contexts. This happens because models use **superposition** to represent more features than they have neurons, making interpretability difficult.
**SAEs solve this** by decomposing dense activations into sparse, monosemantic features - typically only a small number of features activate for any given input, and each feature corresponds to an interpretable concept.
## When to Use SAELens
**Use SAELens when you need to:**
- Discover interpretable features in model activations
- Understand what concepts a model has learned
- Study superposition and feature geometry
- Perform feature-based steering or ablation
- Analyze safety-relevant features (deception, bias, harmful content)
**Consider alternatives when:**
- You need basic activation analysis → Use **TransformerLens** directly
- You want causal intervention experiments → Use **pyvene** or **TransformerLens**
- You need production steering → Consider direct activation engineering
## Installation
```bash
pip install sae-lens
```
Requirements: Python 3.10+, transformer-lens>=2.0.0
## Core Concepts
### What SAEs Learn
SAEs are trained to reconstruct model activations through a sparse bottleneck:
```
Input Activation  →  Encoder  →  Sparse Features  →  Decoder  →  Reconstructed Activation
    (d_model)                   (d_sae >> d_model)                     (d_model)
                                        ↓                                  ↓
                                 sparsity penalty                 reconstruction loss
```
**Loss Function**: `MSE(original, reconstructed) + L1_coefficient × L1(features)`
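A minimal PyTorch sketch of this objective (illustrative only, not the SAELens training code):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinySAE(nn.Module):
    def __init__(self, d_model: int, d_sae: int):
        super().__init__()
        self.encoder = nn.Linear(d_model, d_sae)
        self.decoder = nn.Linear(d_sae, d_model)

    def forward(self, x):
        features = F.relu(self.encoder(x))      # sparse features
        reconstructed = self.decoder(features)  # reconstruction
        return features, reconstructed

sae = TinySAE(d_model=768, d_sae=768 * 8)
x = torch.randn(4, 768)  # a batch of activations

features, reconstructed = sae(x)
l1_coefficient = 8e-5
loss = F.mse_loss(reconstructed, x) + l1_coefficient * features.abs().sum(dim=-1).mean()
print(loss.item())
```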
### Key Validation (Anthropic Research)
In "Towards Monosemanticity", human evaluators found **70% of SAE features genuinely interpretable**. Features discovered include:
- DNA sequences, legal language, HTTP requests
- Hebrew text, nutrition statements, code syntax
- Sentiment, named entities, grammatical structures
## Workflow 1: Loading and Analyzing Pre-trained SAEs
### Step-by-Step
```python
from transformer_lens import HookedTransformer
from sae_lens import SAE
# 1. Load model and pre-trained SAE
model = HookedTransformer.from_pretrained("gpt2-small", device="cuda")
sae, cfg_dict, sparsity = SAE.from_pretrained(
release="gpt2-small-res-jb",
sae_id="blocks.8.hook_resid_pre",
device="cuda"
)
# 2. Get model activations
tokens = model.to_tokens("The capital of France is Paris")
_, cache = model.run_with_cache(tokens)
activations = cache["resid_pre", 8] # [batch, pos, d_model]
# 3. Encode to SAE features
sae_features = sae.encode(activations) # [batch, pos, d_sae]
print(f"Active features: {(sae_features > 0).sum()}")
# 4. Find top features for each position
for pos in range(tokens.shape[1]):
top_features = sae_features[0, pos].topk(5)
token = model.to_str_tokens(tokens[0, pos:pos+1])[0]
print(f"Token '{token}': features {top_features.indices.tolist()}")
# 5. Reconstruct activations
reconstructed = sae.decode(sae_features)
reconstruction_error = (activations - reconstructed).norm()
```
### Available Pre-trained SAEs
| Release | Model | Layers |
|---------|-------|--------|
| `gpt2-small-res-jb` | GPT-2 Small | Multiple residual streams |
| `gemma-2b-res` | Gemma 2B | Residual streams |
| Various on HuggingFace | Search tag `saelens` | Various |
### Checklist
- [ ] Load model with TransformerLens
- [ ] Load matching SAE for target layer
- [ ] Encode activations to sparse features
- [ ] Identify top-activating features per token
- [ ] Validate reconstruction quality
## Workflow 2: Training a Custom SAE
### Step-by-Step
```python
from sae_lens import SAE, LanguageModelSAERunnerConfig, SAETrainingRunner
# 1. Configure training
cfg = LanguageModelSAERunnerConfig(
# Model
model_name="gpt2-small",
hook_name="blocks.8.hook_resid_pre",
hook_layer=8,
d_in=768, # Model dimension
# SAE architecture
architecture="standard", # or "gated", "topk"
d_sae=768 * 8, # Expansion factor of 8
activation_fn="relu",
# Training
lr=4e-4,
l1_coefficient=8e-5, # Sparsity penalty
l1_warm_up_steps=1000,
train_batch_size_tokens=4096,
training_tokens=100_000_000,
# Data
dataset_path="monology/pile-uncopyrighted",
context_size=128,
# Logging
log_to_wandb=True,
wandb_project="sae-training",
# Checkpointing
checkpoint_path="checkpoints",
n_checkpoints=5,
)
# 2. Train
trainer = SAETrainingRunner(cfg)
sae = trainer.run()
# 3. Evaluate
print(f"L0 (avg active features): {trainer.metrics['l0']}")
print(f"CE Loss Recovered: {trainer.metrics['ce_loss_score']}")
```
### Key Hyperparameters
| Parameter | Typical Value | Effect |
|-----------|---------------|--------|
| `d_sae` | 4-16× d_model | More features, higher capacity |
| `l1_coefficient` | 5e-5 to 1e-4 | Higher = sparser, less accurate |
| `lr` | 1e-4 to 1e-3 | Standard optimizer LR |
| `l1_warm_up_steps` | 500-2000 | Prevents early feature death |
### Evaluation Metrics
| Metric | Target | Meaning |
|--------|--------|---------|
| **L0** | 50-200 | Average active features per token |
| **CE Loss Score** | 80-95% | Cross-entropy recovered vs original |
| **Dead Features** | <5% | Features that never activate |
| **Explained Variance** | >90% | Reconstruction quality |
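A sketch of how these metrics can be computed from a batch (assuming `sae` and `activations` as obtained in Workflow 1; dead features should be measured over a much larger corpus in practice):

```python
features = sae.encode(activations)        # [batch, pos, d_sae]
reconstructed = sae.decode(features)

# L0: average number of active features per token
l0 = (features > 0).float().sum(dim=-1).mean()

# Explained variance of the reconstruction
explained_variance = 1 - (activations - reconstructed).var() / activations.var()

# Dead features: never active anywhere in this batch
dead_fraction = ((features > 0).sum(dim=(0, 1)) == 0).float().mean()

print(f"L0={l0.item():.1f}, "
      f"explained var={explained_variance.item():.2%}, "
      f"dead={dead_fraction.item():.2%}")
```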
### Checklist
- [ ] Choose target layer and hook point
- [ ] Set expansion factor (d_sae = 4-16× d_model)
- [ ] Tune L1 coefficient for desired sparsity
- [ ] Enable L1 warm-up to prevent dead features
- [ ] Monitor metrics during training (W&B)
- [ ] Validate L0 and CE loss recovery
- [ ] Check dead feature ratio
## Workflow 3: Feature Analysis and Steering
### Analyzing Individual Features
```python
from transformer_lens import HookedTransformer
from sae_lens import SAE
import torch
model = HookedTransformer.from_pretrained("gpt2-small", device="cuda")
sae, _, _ = SAE.from_pretrained(
release="gpt2-small-res-jb",
sae_id="blocks.8.hook_resid_pre",
device="cuda"
)
# Find what activates a specific feature
feature_idx = 1234
test_texts = [
"The scientist conducted an experiment",
"I love chocolate cake",
"The code compiles successfully",
"Paris is beautiful in spring",
]
for text in test_texts:
tokens = model.to_tokens(text)
_, cache = model.run_with_cache(tokens)
features = sae.encode(cache["resid_pre", 8])
activation = features[0, :, feature_idx].max().item()
print(f"{activation:.3f}: {text}")
```
### Feature Steering
```python
def steer_with_feature(model, sae, prompt, feature_idx, strength=5.0):
    """Add an SAE feature direction to the residual stream during generation."""
    tokens = model.to_tokens(prompt)

    # Get the feature direction from the decoder
    feature_direction = sae.W_dec[feature_idx]  # [d_model]

    def steering_hook(activation, hook):
        # Add the scaled feature direction at all positions
        return activation + strength * feature_direction

    # Generate with the steering hook attached to the residual stream
    with model.hooks(fwd_hooks=[("blocks.8.hook_resid_pre", steering_hook)]):
        output = model.generate(tokens, max_new_tokens=50)
    return model.to_string(output[0])
```
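Example usage (the feature index and strength are arbitrary illustrations; pick a feature you have already characterized):

```python
steered = steer_with_feature(
    model, sae,
    prompt="I think the weather today is",
    feature_idx=1234,  # hypothetical feature index
    strength=8.0,
)
print(steered)
```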
### Feature Attribution
```python
# Which features most affect a specific output?
tokens = model.to_tokens("The capital of France is")
_, cache = model.run_with_cache(tokens)
# Get features at final position
features = sae.encode(cache["resid_pre", 8])[0, -1] # [d_sae]
# Get logit attribution per feature
# Feature contribution = feature_activation × decoder_weight × unembedding
W_dec = sae.W_dec # [d_sae, d_model]
W_U = model.W_U # [d_model, vocab]
# Contribution to "Paris" logit
paris_token = model.to_single_token(" Paris")
feature_contributions = features * (W_dec @ W_U[:, paris_token])
top_features = feature_contributions.topk(10)
print("Top features for 'Paris' prediction:")
for idx, val in zip(top_features.indices, top_features.values):
print(f" Feature {idx.item()}: {val.item():.3f}")
```
## Common Issues & Solutions
### Issue: High dead feature ratio
```python
# WRONG: No warm-up, features die early
cfg = LanguageModelSAERunnerConfig(
l1_coefficient=1e-4,
l1_warm_up_steps=0, # Bad!
)
# RIGHT: Warm-up L1 penalty
cfg = LanguageModelSAERunnerConfig(
l1_coefficient=8e-5,
l1_warm_up_steps=1000, # Gradually increase
use_ghost_grads=True, # Revive dead features
)
```
### Issue: Poor reconstruction (low CE recovery)
```python
# Reduce sparsity penalty
cfg = LanguageModelSAERunnerConfig(
l1_coefficient=5e-5, # Lower = better reconstruction
d_sae=768 * 16, # More capacity
)
```
### Issue: Features not interpretable
```python
# Increase sparsity (higher L1)
cfg = LanguageModelSAERunnerConfig(
l1_coefficient=1e-4, # Higher = sparser, more interpretable
)
# Or use TopK architecture
cfg = LanguageModelSAERunnerConfig(
architecture="topk",
activation_fn_kwargs={"k": 50}, # Exactly 50 active features
)
```
### Issue: Memory errors during training
```python
cfg = LanguageModelSAERunnerConfig(
train_batch_size_tokens=2048, # Reduce batch size
store_batch_size_prompts=4, # Fewer prompts in buffer
n_batches_in_buffer=8, # Smaller activation buffer
)
```
## Integration with Neuronpedia
Browse pre-trained SAE features at [neuronpedia.org](https://neuronpedia.org):
```python
# Features are indexed by SAE ID
# Example: gpt2-small layer 8 feature 1234
# → neuronpedia.org/gpt2-small/8-res-jb/1234
```
## Key Classes Reference
| Class | Purpose |
|-------|---------|
| `SAE` | Sparse Autoencoder model |
| `LanguageModelSAERunnerConfig` | Training configuration |
| `SAETrainingRunner` | Training loop manager |
| `ActivationsStore` | Activation collection and batching |
| `HookedSAETransformer` | TransformerLens + SAE integration |
## Reference Documentation
For detailed API documentation, tutorials, and advanced usage, see the `references/` folder:
| File | Contents |
|------|----------|
| [references/README.md](references/README.md) | Overview and quick start guide |
| [references/api.md](references/api.md) | Complete API reference for SAE, TrainingSAE, configurations |
| [references/tutorials.md](references/tutorials.md) | Step-by-step tutorials for training, analysis, steering |
## External Resources
### Tutorials
- [Basic Loading & Analysis](https://github.com/jbloomAus/SAELens/blob/main/tutorials/basic_loading_and_analysing.ipynb)
- [Training a Sparse Autoencoder](https://github.com/jbloomAus/SAELens/blob/main/tutorials/training_a_sparse_autoencoder.ipynb)
- [ARENA SAE Curriculum](https://www.lesswrong.com/posts/LnHowHgmrMbWtpkxx/intro-to-superposition-and-sparse-autoencoders-colab)
### Papers
- [Towards Monosemanticity](https://transformer-circuits.pub/2023/monosemantic-features) - Anthropic (2023)
- [Scaling Monosemanticity](https://transformer-circuits.pub/2024/scaling-monosemanticity/) - Anthropic (2024)
- [Sparse Autoencoders Find Highly Interpretable Features](https://arxiv.org/abs/2309.08600) - Cunningham et al. (ICLR 2024)
### Official Documentation
- [SAELens Docs](https://jbloomaus.github.io/SAELens/)
- [Neuronpedia](https://neuronpedia.org) - Feature browser
## SAE Architectures
| Architecture | Description | Use Case |
|--------------|-------------|----------|
| **Standard** | ReLU + L1 penalty | General purpose |
| **Gated** | Learned gating mechanism | Better sparsity control |
| **TopK** | Exactly K active features | Consistent sparsity |
```python
# TopK SAE (exactly 50 features active)
cfg = LanguageModelSAERunnerConfig(
architecture="topk",
activation_fn="topk",
activation_fn_kwargs={"k": 50},
)
```


@ -0,0 +1,70 @@
# SAELens Reference Documentation
This directory contains comprehensive reference materials for SAELens.
## Contents
- [api.md](api.md) - Complete API reference for SAE, TrainingSAE, and configuration classes
- [tutorials.md](tutorials.md) - Step-by-step tutorials for training and analyzing SAEs
- [papers.md](papers.md) - Key research papers on sparse autoencoders
## Quick Links
- **GitHub Repository**: https://github.com/jbloomAus/SAELens
- **Neuronpedia**: https://neuronpedia.org (browse pre-trained SAE features)
- **HuggingFace SAEs**: Search for tag `saelens`
## Installation
```bash
pip install sae-lens
```
Requirements: Python 3.10+, transformer-lens>=2.0.0
## Basic Usage
```python
from transformer_lens import HookedTransformer
from sae_lens import SAE
# Load model and SAE
model = HookedTransformer.from_pretrained("gpt2-small", device="cuda")
sae, cfg_dict, sparsity = SAE.from_pretrained(
release="gpt2-small-res-jb",
sae_id="blocks.8.hook_resid_pre",
device="cuda"
)
# Encode activations to sparse features
tokens = model.to_tokens("Hello world")
_, cache = model.run_with_cache(tokens)
activations = cache["resid_pre", 8]
features = sae.encode(activations) # Sparse feature activations
reconstructed = sae.decode(features) # Reconstructed activations
```
## Key Concepts
### Sparse Autoencoders
SAEs decompose dense neural activations into sparse, interpretable features:
- **Encoder**: Maps d_model → d_sae (typically 4-16x expansion)
- **ReLU/TopK**: Enforces sparsity
- **Decoder**: Reconstructs original activations
### Training Loss
`Loss = MSE(original, reconstructed) + L1_coefficient × L1(features)`
### Key Metrics
- **L0**: Average number of active features (target: 50-200)
- **CE Loss Score**: Cross-entropy recovered vs original model (target: 80-95%)
- **Dead Features**: Features that never activate (target: <5%)
## Available Pre-trained SAEs
| Release | Model | Description |
|---------|-------|-------------|
| `gpt2-small-res-jb` | GPT-2 Small | Residual stream SAEs |
| `gemma-2b-res` | Gemma 2B | Residual stream SAEs |
| Various | Search HuggingFace | Community-trained SAEs |


@ -0,0 +1,333 @@
# SAELens API Reference
## SAE Class
The core class representing a Sparse Autoencoder.
### Loading Pre-trained SAEs
```python
from sae_lens import SAE
# From official releases
sae, cfg_dict, sparsity = SAE.from_pretrained(
release="gpt2-small-res-jb",
sae_id="blocks.8.hook_resid_pre",
device="cuda"
)
# From HuggingFace
sae, cfg_dict, sparsity = SAE.from_pretrained(
release="username/repo-name",
sae_id="path/to/sae",
device="cuda"
)
# From local disk
sae = SAE.load_from_disk("/path/to/sae", device="cuda")
```
### SAE Attributes
| Attribute | Shape | Description |
|-----------|-------|-------------|
| `W_enc` | [d_in, d_sae] | Encoder weights |
| `W_dec` | [d_sae, d_in] | Decoder weights |
| `b_enc` | [d_sae] | Encoder bias |
| `b_dec` | [d_in] | Decoder bias |
| `cfg` | SAEConfig | Configuration object |
### Core Methods
#### encode()
```python
# Encode activations to sparse features
features = sae.encode(activations)
# Input: [batch, pos, d_in]
# Output: [batch, pos, d_sae]
```
#### decode()
```python
# Reconstruct activations from features
reconstructed = sae.decode(features)
# Input: [batch, pos, d_sae]
# Output: [batch, pos, d_in]
```
#### forward()
```python
# Full forward pass (encode + decode)
reconstructed = sae(activations)
# Returns reconstructed activations
```
#### save_model()
```python
sae.save_model("/path/to/save")
```
---
## SAEConfig
Configuration class for SAE architecture and training context.
### Key Parameters
| Parameter | Type | Description |
|-----------|------|-------------|
| `d_in` | int | Input dimension (model's d_model) |
| `d_sae` | int | SAE hidden dimension |
| `architecture` | str | "standard", "gated", "jumprelu", "topk" |
| `activation_fn_str` | str | Activation function name |
| `model_name` | str | Source model name |
| `hook_name` | str | Hook point in model |
| `normalize_activations` | str | Normalization method |
| `dtype` | str | Data type |
| `device` | str | Device |
### Accessing Config
```python
print(sae.cfg.d_in) # 768 for GPT-2 small
print(sae.cfg.d_sae) # e.g., 24576 (32x expansion)
print(sae.cfg.hook_name) # e.g., "blocks.8.hook_resid_pre"
```
---
## LanguageModelSAERunnerConfig
Comprehensive configuration for training SAEs.
### Example Configuration
```python
from sae_lens import LanguageModelSAERunnerConfig
cfg = LanguageModelSAERunnerConfig(
# Model and hook
model_name="gpt2-small",
hook_name="blocks.8.hook_resid_pre",
hook_layer=8,
d_in=768,
# SAE architecture
architecture="standard", # "standard", "gated", "jumprelu", "topk"
d_sae=768 * 8, # Expansion factor
activation_fn="relu",
# Training hyperparameters
lr=4e-4,
l1_coefficient=8e-5,
lp_norm=1.0,
lr_scheduler_name="constant",
lr_warm_up_steps=500,
# Sparsity control
l1_warm_up_steps=1000,
use_ghost_grads=True,
feature_sampling_window=1000,
dead_feature_window=5000,
dead_feature_threshold=1e-8,
# Data
dataset_path="monology/pile-uncopyrighted",
streaming=True,
context_size=128,
# Batch sizes
train_batch_size_tokens=4096,
store_batch_size_prompts=16,
n_batches_in_buffer=64,
# Training duration
training_tokens=100_000_000,
# Logging
log_to_wandb=True,
wandb_project="sae-training",
wandb_log_frequency=100,
# Checkpointing
checkpoint_path="checkpoints",
n_checkpoints=5,
# Hardware
device="cuda",
dtype="float32",
)
```
### Key Parameters Explained
#### Architecture Parameters
| Parameter | Description |
|-----------|-------------|
| `architecture` | SAE type: "standard", "gated", "jumprelu", "topk" |
| `d_sae` | Hidden dimension (or use `expansion_factor`) |
| `expansion_factor` | Alternative to d_sae: d_sae = d_in × expansion_factor |
| `activation_fn` | "relu", "topk", etc. |
| `activation_fn_kwargs` | Dict for activation params (e.g., {"k": 50} for topk) |
#### Sparsity Parameters
| Parameter | Description |
|-----------|-------------|
| `l1_coefficient` | L1 penalty weight (higher = sparser) |
| `l1_warm_up_steps` | Steps to ramp up L1 penalty |
| `use_ghost_grads` | Apply gradients to dead features |
| `dead_feature_threshold` | Activation threshold for "dead" |
| `dead_feature_window` | Steps to check for dead features |
#### Learning Rate Parameters
| Parameter | Description |
|-----------|-------------|
| `lr` | Base learning rate |
| `lr_scheduler_name` | "constant", "cosineannealing", etc. |
| `lr_warm_up_steps` | LR warmup steps |
| `lr_decay_steps` | Steps for LR decay |
---
## SAETrainingRunner
Main class for executing training.
### Basic Training
```python
from sae_lens import SAETrainingRunner, LanguageModelSAERunnerConfig
cfg = LanguageModelSAERunnerConfig(...)
runner = SAETrainingRunner(cfg)
sae = runner.run()
```
### Accessing Training Metrics
```python
# During training, metrics logged to W&B include:
# - l0: Average active features
# - ce_loss_score: Cross-entropy recovery
# - mse_loss: Reconstruction loss
# - l1_loss: Sparsity loss
# - dead_features: Count of dead features
```
---
## ActivationsStore
Manages activation collection and batching.
### Basic Usage
```python
from sae_lens import ActivationsStore
store = ActivationsStore.from_sae(
model=model,
sae=sae,
store_batch_size_prompts=8,
train_batch_size_tokens=4096,
n_batches_in_buffer=32,
device="cuda",
)
# Get batch of activations
activations = store.get_batch_tokens()
```
---
## HookedSAETransformer
Integration of SAEs with TransformerLens models.
### Basic Usage
```python
from sae_lens import HookedSAETransformer
# Load model with SAE
model = HookedSAETransformer.from_pretrained("gpt2-small")
model.add_sae(sae)
# Run with SAE in the loop
output = model.run_with_saes(tokens, saes=[sae])
# Cache with SAE activations
output, cache = model.run_with_cache_with_saes(tokens, saes=[sae])
```
---
## SAE Architectures
### Standard (ReLU + L1)
```python
cfg = LanguageModelSAERunnerConfig(
architecture="standard",
activation_fn="relu",
l1_coefficient=8e-5,
)
```
### Gated
```python
cfg = LanguageModelSAERunnerConfig(
architecture="gated",
)
```
### TopK
```python
cfg = LanguageModelSAERunnerConfig(
architecture="topk",
activation_fn="topk",
activation_fn_kwargs={"k": 50}, # Exactly 50 active features
)
```
### JumpReLU (State-of-the-art)
```python
cfg = LanguageModelSAERunnerConfig(
architecture="jumprelu",
)
```
---
## Utility Functions
### Upload to HuggingFace
```python
from sae_lens import upload_saes_to_huggingface
upload_saes_to_huggingface(
saes=[sae],
repo_id="username/my-saes",
token="hf_token",
)
```
### Neuronpedia Integration
```python
# Features can be viewed on Neuronpedia
# URL format: neuronpedia.org/{model}/{layer}-{sae_type}/{feature_id}
# Example: neuronpedia.org/gpt2-small/8-res-jb/1234
```


@ -0,0 +1,318 @@
# SAELens Tutorials
## Tutorial 1: Loading and Analyzing Pre-trained SAEs
### Goal
Load a pre-trained SAE and analyze which features activate on specific inputs.
### Step-by-Step
```python
from transformer_lens import HookedTransformer
from sae_lens import SAE
import torch
# 1. Load model and SAE
model = HookedTransformer.from_pretrained("gpt2-small", device="cuda")
sae, cfg_dict, sparsity = SAE.from_pretrained(
release="gpt2-small-res-jb",
sae_id="blocks.8.hook_resid_pre",
device="cuda"
)
print(f"SAE input dim: {sae.cfg.d_in}")
print(f"SAE hidden dim: {sae.cfg.d_sae}")
print(f"Expansion factor: {sae.cfg.d_sae / sae.cfg.d_in:.1f}x")
# 2. Get model activations
prompt = "The capital of France is Paris"
tokens = model.to_tokens(prompt)
_, cache = model.run_with_cache(tokens)
activations = cache["resid_pre", 8] # [1, seq_len, 768]
# 3. Encode to SAE features
features = sae.encode(activations) # [1, seq_len, d_sae]
# 4. Analyze sparsity
active_per_token = (features > 0).sum(dim=-1)
print(f"Average active features per token: {active_per_token.float().mean():.1f}")
# 5. Find top features for each token
str_tokens = model.to_str_tokens(prompt)
for pos in range(len(str_tokens)):
top_features = features[0, pos].topk(5)
print(f"\nToken '{str_tokens[pos]}':")
for feat_idx, feat_val in zip(top_features.indices, top_features.values):
print(f" Feature {feat_idx.item()}: {feat_val.item():.3f}")
# 6. Check reconstruction quality
reconstructed = sae.decode(features)
mse = ((activations - reconstructed) ** 2).mean()
print(f"\nReconstruction MSE: {mse.item():.6f}")
```
---
## Tutorial 2: Training a Custom SAE
### Goal
Train a Sparse Autoencoder on GPT-2 activations.
### Step-by-Step
```python
from sae_lens import LanguageModelSAERunnerConfig, SAETrainingRunner
# 1. Configure training
cfg = LanguageModelSAERunnerConfig(
# Model
model_name="gpt2-small",
hook_name="blocks.6.hook_resid_pre",
hook_layer=6,
d_in=768,
# SAE architecture
architecture="standard",
d_sae=768 * 8, # 8x expansion
activation_fn="relu",
# Training
lr=4e-4,
l1_coefficient=8e-5,
l1_warm_up_steps=1000,
train_batch_size_tokens=4096,
training_tokens=10_000_000, # Small run for demo
# Data
dataset_path="monology/pile-uncopyrighted",
streaming=True,
context_size=128,
# Dead feature prevention
use_ghost_grads=True,
dead_feature_window=5000,
# Logging
log_to_wandb=True,
wandb_project="sae-training-demo",
# Hardware
device="cuda",
dtype="float32",
)
# 2. Train
runner = SAETrainingRunner(cfg)
sae = runner.run()
# 3. Save
sae.save_model("./my_trained_sae")
```
### Hyperparameter Tuning Guide
| If you see... | Try... |
|---------------|--------|
| High L0 (>200) | Increase `l1_coefficient` |
| Low CE recovery (<80%) | Decrease `l1_coefficient`, increase `d_sae` |
| Many dead features (>5%) | Enable `use_ghost_grads`, increase `l1_warm_up_steps` |
| Training instability | Lower `lr`, increase `lr_warm_up_steps` |
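For instance, if the demo run above comes back with a high L0 and a sizable fraction of dead features, a follow-up config might look like the sketch below (the adjusted values are illustrative, not tuned recommendations):
```python
# Hypothetical second run reacting to high L0 and dead features in the first run
cfg_v2 = LanguageModelSAERunnerConfig(
    model_name="gpt2-small",
    hook_name="blocks.6.hook_resid_pre",
    hook_layer=6,
    d_in=768,
    architecture="standard",
    d_sae=768 * 8,
    l1_coefficient=1.6e-4,    # doubled from 8e-5 to push L0 down
    l1_warm_up_steps=2000,    # longer ramp to reduce early feature death
    use_ghost_grads=True,
    lr=4e-4,
    train_batch_size_tokens=4096,
    training_tokens=10_000_000,
    dataset_path="monology/pile-uncopyrighted",
    streaming=True,
    context_size=128,
    device="cuda",
)
```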
---
## Tutorial 3: Feature Attribution and Steering
### Goal
Identify which SAE features contribute to specific predictions and use them for steering.
### Step-by-Step
```python
from transformer_lens import HookedTransformer
from sae_lens import SAE
import torch
model = HookedTransformer.from_pretrained("gpt2-small", device="cuda")
sae, _, _ = SAE.from_pretrained(
release="gpt2-small-res-jb",
sae_id="blocks.8.hook_resid_pre",
device="cuda"
)
# 1. Feature attribution for a specific prediction
prompt = "The capital of France is"
tokens = model.to_tokens(prompt)
_, cache = model.run_with_cache(tokens)
activations = cache["resid_pre", 8]
features = sae.encode(activations)
# Target token
target_token = model.to_single_token(" Paris")
# Compute feature contributions to target logit
# contribution = feature_activation * decoder_weight * unembedding
W_dec = sae.W_dec # [d_sae, d_model]
W_U = model.W_U # [d_model, d_vocab]
# Feature direction projected to vocabulary
feature_to_logit = W_dec @ W_U # [d_sae, d_vocab]
# Contribution of each feature to "Paris" at final position
feature_acts = features[0, -1] # [d_sae]
contributions = feature_acts * feature_to_logit[:, target_token]
# Top contributing features
top_features = contributions.topk(10)
print("Top features contributing to 'Paris':")
for idx, val in zip(top_features.indices, top_features.values):
print(f" Feature {idx.item()}: {val.item():.3f}")
# 2. Feature steering
def steer_with_feature(feature_idx, strength=5.0):
"""Add a feature direction to the residual stream."""
feature_direction = sae.W_dec[feature_idx] # [d_model]
def hook(activation, hook_obj):
activation[:, -1, :] += strength * feature_direction
return activation
    # generate() doesn't accept fwd_hooks directly, so wrap it in the hooks context manager
    with model.hooks(fwd_hooks=[("blocks.8.hook_resid_pre", hook)]):
        output = model.generate(tokens, max_new_tokens=10)
return model.to_string(output[0])
# Try steering with top feature
top_feature_idx = top_features.indices[0].item()
print(f"\nSteering with feature {top_feature_idx}:")
print(steer_with_feature(top_feature_idx, strength=10.0))
```
---
## Tutorial 4: Feature Ablation
### Goal
Test the causal importance of features by ablating them.
### Step-by-Step
```python
from transformer_lens import HookedTransformer
from sae_lens import SAE
import torch
model = HookedTransformer.from_pretrained("gpt2-small", device="cuda")
sae, _, _ = SAE.from_pretrained(
release="gpt2-small-res-jb",
sae_id="blocks.8.hook_resid_pre",
device="cuda"
)
prompt = "The capital of France is"
tokens = model.to_tokens(prompt)
# Baseline prediction
baseline_logits = model(tokens)
target_token = model.to_single_token(" Paris")
baseline_prob = torch.softmax(baseline_logits[0, -1], dim=-1)[target_token].item()
print(f"Baseline P(Paris): {baseline_prob:.4f}")
# Get features to ablate
_, cache = model.run_with_cache(tokens)
activations = cache["resid_pre", 8]
features = sae.encode(activations)
top_features = features[0, -1].topk(10).indices
# Ablate top features one by one
for feat_idx in top_features:
def ablation_hook(activation, hook, feat_idx=feat_idx):
# Encode → zero feature → decode
feats = sae.encode(activation)
feats[:, :, feat_idx] = 0
return sae.decode(feats)
ablated_logits = model.run_with_hooks(
tokens,
fwd_hooks=[("blocks.8.hook_resid_pre", ablation_hook)]
)
ablated_prob = torch.softmax(ablated_logits[0, -1], dim=-1)[target_token].item()
change = (ablated_prob - baseline_prob) / baseline_prob * 100
print(f"Ablate feature {feat_idx.item()}: P(Paris)={ablated_prob:.4f} ({change:+.1f}%)")
```
---
## Tutorial 5: Comparing Features Across Prompts
### Goal
Find which features activate consistently for a concept.
### Step-by-Step
```python
from transformer_lens import HookedTransformer
from sae_lens import SAE
import torch
model = HookedTransformer.from_pretrained("gpt2-small", device="cuda")
sae, _, _ = SAE.from_pretrained(
release="gpt2-small-res-jb",
sae_id="blocks.8.hook_resid_pre",
device="cuda"
)
# Test prompts about the same concept
prompts = [
"The Eiffel Tower is located in",
"Paris is the capital of",
"France's largest city is",
"The Louvre museum is in",
]
# Collect feature activations
all_features = []
for prompt in prompts:
tokens = model.to_tokens(prompt)
_, cache = model.run_with_cache(tokens)
activations = cache["resid_pre", 8]
features = sae.encode(activations)
# Take max activation across positions
max_features = features[0].max(dim=0).values
all_features.append(max_features)
all_features = torch.stack(all_features) # [n_prompts, d_sae]
# Find features that activate consistently
mean_activation = all_features.mean(dim=0)
min_activation = all_features.min(dim=0).values
# Features active in ALL prompts
consistent_features = (min_activation > 0.5).nonzero().squeeze(-1)
print(f"Features active in all prompts: {len(consistent_features)}")
# Top consistent features
top_consistent = mean_activation[consistent_features].topk(min(10, len(consistent_features)))
print("\nTop consistent features (possibly 'France/Paris' related):")
for idx, val in zip(top_consistent.indices, top_consistent.values):
feat_idx = consistent_features[idx].item()
print(f" Feature {feat_idx}: mean activation {val.item():.3f}")
```
---
## External Resources
### Official Tutorials
- [Basic Loading & Analysis](https://github.com/jbloomAus/SAELens/blob/main/tutorials/basic_loading_and_analysing.ipynb)
- [Training SAEs](https://github.com/jbloomAus/SAELens/blob/main/tutorials/training_a_sparse_autoencoder.ipynb)
- [Logits Lens with Features](https://github.com/jbloomAus/SAELens/blob/main/tutorials/logits_lens_with_features.ipynb)
### ARENA Curriculum
Comprehensive SAE course: https://www.lesswrong.com/posts/LnHowHgmrMbWtpkxx/intro-to-superposition-and-sparse-autoencoders-colab
### Key Papers
- [Towards Monosemanticity](https://transformer-circuits.pub/2023/monosemantic-features) - Anthropic (2023)
- [Scaling Monosemanticity](https://transformer-circuits.pub/2024/scaling-monosemanticity/) - Anthropic (2024)
- [Sparse Autoencoders Find Interpretable Features](https://arxiv.org/abs/2309.08600) - ICLR 2024

View file

@ -0,0 +1,593 @@
---
name: weights-and-biases
description: Track ML experiments with automatic logging, visualize training in real-time, optimize hyperparameters with sweeps, and manage model registry with W&B - collaborative MLOps platform
version: 1.0.0
author: Orchestra Research
license: MIT
dependencies: [wandb]
metadata:
hermes:
tags: [MLOps, Weights And Biases, WandB, Experiment Tracking, Hyperparameter Tuning, Model Registry, Collaboration, Real-Time Visualization, PyTorch, TensorFlow, HuggingFace]
---
# Weights & Biases: ML Experiment Tracking & MLOps
## When to Use This Skill
Use Weights & Biases (W&B) when you need to:
- **Track ML experiments** with automatic metric logging
- **Visualize training** in real-time dashboards
- **Compare runs** across hyperparameters and configurations
- **Optimize hyperparameters** with automated sweeps
- **Manage model registry** with versioning and lineage
- **Collaborate on ML projects** with team workspaces
- **Track artifacts** (datasets, models, code) with lineage
**Users**: 200,000+ ML practitioners | **GitHub Stars**: 10.5k+ | **Integrations**: 100+
## Installation
```bash
# Install W&B
pip install wandb
# Login (creates API key)
wandb login
# Or set API key programmatically
export WANDB_API_KEY=your_api_key_here
```
## Quick Start
### Basic Experiment Tracking
```python
import wandb
# Initialize a run
run = wandb.init(
project="my-project",
config={
"learning_rate": 0.001,
"epochs": 10,
"batch_size": 32,
"architecture": "ResNet50"
}
)
# Training loop
for epoch in range(run.config.epochs):
# Your training code
    train_loss, train_acc = train_epoch()
    val_loss, val_acc = validate()
# Log metrics
wandb.log({
"epoch": epoch,
"train/loss": train_loss,
"val/loss": val_loss,
"train/accuracy": train_acc,
"val/accuracy": val_acc
})
# Finish the run
wandb.finish()
```
### With PyTorch
```python
import torch
import wandb
# Initialize
wandb.init(project="pytorch-demo", config={
"lr": 0.001,
"epochs": 10
})
# Access config
config = wandb.config
# Training loop
for epoch in range(config.epochs):
for batch_idx, (data, target) in enumerate(train_loader):
# Forward pass
output = model(data)
loss = criterion(output, target)
# Backward pass
optimizer.zero_grad()
loss.backward()
optimizer.step()
# Log every 100 batches
if batch_idx % 100 == 0:
wandb.log({
"loss": loss.item(),
"epoch": epoch,
"batch": batch_idx
})
# Save model
torch.save(model.state_dict(), "model.pth")
wandb.save("model.pth") # Upload to W&B
wandb.finish()
```
## Core Concepts
### 1. Projects and Runs
**Project**: Collection of related experiments
**Run**: Single execution of your training script
```python
# Create/use project
run = wandb.init(
project="image-classification",
name="resnet50-experiment-1", # Optional run name
tags=["baseline", "resnet"], # Organize with tags
notes="First baseline run" # Add notes
)
# Each run has unique ID
print(f"Run ID: {run.id}")
print(f"Run URL: {run.url}")
```
### 2. Configuration Tracking
Track hyperparameters automatically:
```python
config = {
# Model architecture
"model": "ResNet50",
"pretrained": True,
# Training params
"learning_rate": 0.001,
"batch_size": 32,
"epochs": 50,
"optimizer": "Adam",
# Data params
"dataset": "ImageNet",
"augmentation": "standard"
}
wandb.init(project="my-project", config=config)
# Access config during training
lr = wandb.config.learning_rate
batch_size = wandb.config.batch_size
```
### 3. Metric Logging
```python
# Log scalars
wandb.log({"loss": 0.5, "accuracy": 0.92})
# Log multiple metrics
wandb.log({
"train/loss": train_loss,
"train/accuracy": train_acc,
"val/loss": val_loss,
"val/accuracy": val_acc,
"learning_rate": current_lr,
"epoch": epoch
})
# Log with custom x-axis
wandb.log({"loss": loss}, step=global_step)
# Log media (images, audio, video)
wandb.log({"examples": [wandb.Image(img) for img in images]})
# Log histograms
wandb.log({"gradients": wandb.Histogram(gradients)})
# Log tables
table = wandb.Table(columns=["id", "prediction", "ground_truth"])
wandb.log({"predictions": table})
```
### 4. Model Checkpointing
```python
import torch
import wandb
# Save model checkpoint
checkpoint = {
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss,
}
torch.save(checkpoint, 'checkpoint.pth')
# Upload to W&B
wandb.save('checkpoint.pth')
# Or use Artifacts (recommended)
artifact = wandb.Artifact('model', type='model')
artifact.add_file('checkpoint.pth')
wandb.log_artifact(artifact)
```
## Hyperparameter Sweeps
Automatically search for optimal hyperparameters.
### Define Sweep Configuration
```python
sweep_config = {
'method': 'bayes', # or 'grid', 'random'
'metric': {
'name': 'val/accuracy',
'goal': 'maximize'
},
'parameters': {
'learning_rate': {
            'distribution': 'log_uniform_values',
'min': 1e-5,
'max': 1e-1
},
'batch_size': {
'values': [16, 32, 64, 128]
},
'optimizer': {
'values': ['adam', 'sgd', 'rmsprop']
},
'dropout': {
'distribution': 'uniform',
'min': 0.1,
'max': 0.5
}
}
}
# Initialize sweep
sweep_id = wandb.sweep(sweep_config, project="my-project")
```
### Define Training Function
```python
def train():
# Initialize run
run = wandb.init()
# Access sweep parameters
lr = wandb.config.learning_rate
batch_size = wandb.config.batch_size
optimizer_name = wandb.config.optimizer
# Build model with sweep config
model = build_model(wandb.config)
optimizer = get_optimizer(optimizer_name, lr)
# Training loop
for epoch in range(NUM_EPOCHS):
train_loss = train_epoch(model, optimizer, batch_size)
val_acc = validate(model)
# Log metrics
wandb.log({
"train/loss": train_loss,
"val/accuracy": val_acc
})
# Run sweep
wandb.agent(sweep_id, function=train, count=50) # Run 50 trials
```
### Sweep Strategies
```python
# Grid search - exhaustive
sweep_config = {
'method': 'grid',
'parameters': {
'lr': {'values': [0.001, 0.01, 0.1]},
'batch_size': {'values': [16, 32, 64]}
}
}
# Random search
sweep_config = {
'method': 'random',
'parameters': {
'lr': {'distribution': 'uniform', 'min': 0.0001, 'max': 0.1},
'dropout': {'distribution': 'uniform', 'min': 0.1, 'max': 0.5}
}
}
# Bayesian optimization (recommended)
sweep_config = {
'method': 'bayes',
'metric': {'name': 'val/loss', 'goal': 'minimize'},
'parameters': {
        'lr': {'distribution': 'log_uniform_values', 'min': 1e-5, 'max': 1e-1}
}
}
```
## Artifacts
Track datasets, models, and other files with lineage.
### Log Artifacts
```python
# Create artifact
artifact = wandb.Artifact(
name='training-dataset',
type='dataset',
description='ImageNet training split',
metadata={'size': '1.2M images', 'split': 'train'}
)
# Add files
artifact.add_file('data/train.csv')
artifact.add_dir('data/images/')
# Log artifact
wandb.log_artifact(artifact)
```
### Use Artifacts
```python
# Download and use artifact
run = wandb.init(project="my-project")
# Download artifact
artifact = run.use_artifact('training-dataset:latest')
artifact_dir = artifact.download()
# Use the data
data = load_data(f"{artifact_dir}/train.csv")
```
### Model Registry
```python
# Log model as artifact
model_artifact = wandb.Artifact(
name='resnet50-model',
type='model',
metadata={'architecture': 'ResNet50', 'accuracy': 0.95}
)
model_artifact.add_file('model.pth')
wandb.log_artifact(model_artifact, aliases=['best', 'production'])
# Link to model registry
run.link_artifact(model_artifact, 'model-registry/production-models')
```
## Integration Examples
### HuggingFace Transformers
```python
from transformers import Trainer, TrainingArguments
import wandb
# Initialize W&B
wandb.init(project="hf-transformers")
# Training arguments with W&B
training_args = TrainingArguments(
output_dir="./results",
report_to="wandb", # Enable W&B logging
run_name="bert-finetuning",
logging_steps=100,
save_steps=500
)
# Trainer automatically logs to W&B
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset
)
trainer.train()
```
### PyTorch Lightning
```python
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger
import wandb
# Create W&B logger
wandb_logger = WandbLogger(
project="lightning-demo",
log_model=True # Log model checkpoints
)
# Use with Trainer
trainer = Trainer(
logger=wandb_logger,
max_epochs=10
)
trainer.fit(model, datamodule=dm)
```
### Keras/TensorFlow
```python
import wandb
from wandb.keras import WandbCallback
# Initialize
wandb.init(project="keras-demo")
# Add callback
model.fit(
x_train, y_train,
validation_data=(x_val, y_val),
epochs=10,
callbacks=[WandbCallback()] # Auto-logs metrics
)
```
## Visualization & Analysis
### Custom Charts
```python
# Log custom visualizations
import matplotlib.pyplot as plt
fig, ax = plt.subplots()
ax.plot(x, y)
wandb.log({"custom_plot": wandb.Image(fig)})
# Log confusion matrix
wandb.log({"conf_mat": wandb.plot.confusion_matrix(
probs=None,
y_true=ground_truth,
preds=predictions,
class_names=class_names
)})
```
### Reports
Create shareable reports in W&B UI:
- Combine runs, charts, and text
- Markdown support
- Embeddable visualizations
- Team collaboration
## Best Practices
### 1. Organize with Tags and Groups
```python
wandb.init(
project="my-project",
tags=["baseline", "resnet50", "imagenet"],
group="resnet-experiments", # Group related runs
job_type="train" # Type of job
)
```
### 2. Log Everything Relevant
```python
# Log system metrics
wandb.log({
"gpu/util": gpu_utilization,
"gpu/memory": gpu_memory_used,
"cpu/util": cpu_utilization
})
# Log code version
wandb.log({"git_commit": git_commit_hash})
# Log data splits
wandb.log({
"data/train_size": len(train_dataset),
"data/val_size": len(val_dataset)
})
```
### 3. Use Descriptive Names
```python
# ✅ Good: Descriptive run names
wandb.init(
project="nlp-classification",
name="bert-base-lr0.001-bs32-epoch10"
)
# ❌ Bad: Generic names
wandb.init(project="nlp", name="run1")
```
### 4. Save Important Artifacts
```python
# Save final model
artifact = wandb.Artifact('final-model', type='model')
artifact.add_file('model.pth')
wandb.log_artifact(artifact)
# Save predictions for analysis
predictions_table = wandb.Table(
columns=["id", "input", "prediction", "ground_truth"],
data=predictions_data
)
wandb.log({"predictions": predictions_table})
```
### 5. Use Offline Mode for Unstable Connections
```python
import os
# Enable offline mode
os.environ["WANDB_MODE"] = "offline"
wandb.init(project="my-project")
# ... your code ...
# Sync later
# wandb sync <run_directory>
```
## Team Collaboration
### Share Runs
```python
# Runs are automatically shareable via URL
run = wandb.init(project="team-project")
print(f"Share this URL: {run.url}")
```
### Team Projects
- Create team account at wandb.ai
- Add team members
- Set project visibility (private/public)
- Use team-level artifacts and model registry
## Pricing
- **Free**: Unlimited public projects, 100GB storage
- **Academic**: Free for students/researchers
- **Teams**: $50/seat/month, private projects, unlimited storage
- **Enterprise**: Custom pricing, on-prem options
## Resources
- **Documentation**: https://docs.wandb.ai
- **GitHub**: https://github.com/wandb/wandb (10.5k+ stars)
- **Examples**: https://github.com/wandb/examples
- **Community**: https://wandb.ai/community
- **Discord**: https://wandb.me/discord
## See Also
- `references/sweeps.md` - Comprehensive hyperparameter optimization guide
- `references/artifacts.md` - Data and model versioning patterns
- `references/integrations.md` - Framework-specific examples

View file

@ -0,0 +1,584 @@
# Artifacts & Model Registry Guide
Complete guide to data versioning and model management with W&B Artifacts.
## Table of Contents
- What are Artifacts
- Creating Artifacts
- Using Artifacts
- Model Registry
- Versioning & Lineage
- Best Practices
## What are Artifacts
Artifacts are versioned datasets, models, or files tracked with lineage.
**Key Features:**
- Automatic versioning (v0, v1, v2...)
- Lineage tracking (which runs produced/used artifacts)
- Efficient storage (deduplication)
- Collaboration (team-wide access)
- Aliases (latest, best, production)
**Common Use Cases:**
- Dataset versioning
- Model checkpoints
- Preprocessed data
- Evaluation results
- Configuration files
## Creating Artifacts
### Basic Dataset Artifact
```python
import wandb
run = wandb.init(project="my-project")
# Create artifact
dataset = wandb.Artifact(
name='training-data',
type='dataset',
description='ImageNet training split with augmentations',
metadata={
'size': '1.2M images',
'format': 'JPEG',
'resolution': '224x224'
}
)
# Add files
dataset.add_file('data/train.csv') # Single file
dataset.add_dir('data/images') # Entire directory
dataset.add_reference('s3://bucket/data') # Cloud reference
# Log artifact
run.log_artifact(dataset)
wandb.finish()
```
### Model Artifact
```python
import torch
import wandb
run = wandb.init(project="my-project")
# Train model
model = train_model()
# Save model
torch.save(model.state_dict(), 'model.pth')
# Create model artifact
model_artifact = wandb.Artifact(
name='resnet50-classifier',
type='model',
description='ResNet50 trained on ImageNet',
metadata={
'architecture': 'ResNet50',
'accuracy': 0.95,
'loss': 0.15,
'epochs': 50,
'framework': 'PyTorch'
}
)
# Add model file
model_artifact.add_file('model.pth')
# Add config
model_artifact.add_file('config.yaml')
# Log with aliases
run.log_artifact(model_artifact, aliases=['latest', 'best'])
wandb.finish()
```
### Preprocessed Data Artifact
```python
import pandas as pd
import wandb
run = wandb.init(project="nlp-project")
# Preprocess data
df = pd.read_csv('raw_data.csv')
df_processed = preprocess(df)
df_processed.to_csv('processed_data.csv', index=False)
# Create artifact
processed_data = wandb.Artifact(
name='processed-text-data',
type='dataset',
metadata={
'rows': len(df_processed),
'columns': list(df_processed.columns),
'preprocessing_steps': ['lowercase', 'remove_stopwords', 'tokenize']
}
)
processed_data.add_file('processed_data.csv')
# Log artifact
run.log_artifact(processed_data)
```
## Using Artifacts
### Download and Use
```python
import wandb
run = wandb.init(project="my-project")
# Download artifact
artifact = run.use_artifact('training-data:latest')
artifact_dir = artifact.download()
# Use files
import pandas as pd
df = pd.read_csv(f'{artifact_dir}/train.csv')
# Train with artifact data
model = train_model(df)
```
### Use Specific Version
```python
# Use specific version
artifact_v2 = run.use_artifact('training-data:v2')
# Use alias
artifact_best = run.use_artifact('model:best')
artifact_prod = run.use_artifact('model:production')
# Use from another project
artifact = run.use_artifact('team/other-project/model:latest')
```
### Check Artifact Metadata
```python
artifact = run.use_artifact('training-data:latest')
# Access metadata
print(artifact.metadata)
print(f"Size: {artifact.metadata['size']}")
# Access version info
print(f"Version: {artifact.version}")
print(f"Created at: {artifact.created_at}")
print(f"Digest: {artifact.digest}")
```
## Model Registry
Link models to a central registry for governance and deployment.
### Create Model Registry
```python
# In W&B UI:
# 1. Go to "Registry" tab
# 2. Create new registry: "production-models"
# 3. Define stages: development, staging, production
```
### Link Model to Registry
```python
import wandb
run = wandb.init(project="training")
# Create model artifact
model_artifact = wandb.Artifact(
name='sentiment-classifier',
type='model',
metadata={'accuracy': 0.94, 'f1': 0.92}
)
model_artifact.add_file('model.pth')
# Log artifact
run.log_artifact(model_artifact)
# Link to registry
run.link_artifact(
model_artifact,
'model-registry/production-models',
aliases=['staging'] # Deploy to staging
)
wandb.finish()
```
### Promote Model in Registry
```python
# Retrieve model from registry
api = wandb.Api()
artifact = api.artifact('model-registry/production-models/sentiment-classifier:staging')
# Promote to production
artifact.link('model-registry/production-models', aliases=['production'])
# Demote from production
artifact.aliases = ['archived']
artifact.save()
```
### Use Model from Registry
```python
import wandb
run = wandb.init()
# Download production model
model_artifact = run.use_artifact(
'model-registry/production-models/sentiment-classifier:production'
)
model_dir = model_artifact.download()
# Load and use
import torch
# 'model.pth' holds a state_dict, so load it into an instance of the same architecture
model = build_model()  # hypothetical constructor for the original architecture
model.load_state_dict(torch.load(f'{model_dir}/model.pth'))
model.eval()
```
## Versioning & Lineage
### Automatic Versioning
```python
# First log: creates v0
run1 = wandb.init(project="my-project")
dataset_v0 = wandb.Artifact('my-dataset', type='dataset')
dataset_v0.add_file('data_v1.csv')
run1.log_artifact(dataset_v0)
# Second log with same name: creates v1
run2 = wandb.init(project="my-project")
dataset_v1 = wandb.Artifact('my-dataset', type='dataset')
dataset_v1.add_file('data_v2.csv') # Different content
run2.log_artifact(dataset_v1)
# Third log with SAME content as v1: references v1 (no new version)
run3 = wandb.init(project="my-project")
dataset_v1_again = wandb.Artifact('my-dataset', type='dataset')
dataset_v1_again.add_file('data_v2.csv') # Same content as v1
run3.log_artifact(dataset_v1_again) # Still v1, no v2 created
```
### Track Lineage
```python
# Training run
run = wandb.init(project="my-project")
# Use dataset (input)
dataset = run.use_artifact('training-data:v3')
data = load_data(dataset.download())
# Train model
model = train(data)
# Save model (output)
model_artifact = wandb.Artifact('trained-model', type='model')
torch.save(model.state_dict(), 'model.pth')
model_artifact.add_file('model.pth')
run.log_artifact(model_artifact)
# Lineage automatically tracked:
# training-data:v3 --> [run] --> trained-model:v0
```
### View Lineage Graph
```python
# In W&B UI:
# Artifacts → Select artifact → Lineage tab
# Shows:
# - Which runs produced this artifact
# - Which runs used this artifact
# - Parent/child artifacts
```
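The same lineage can also be queried programmatically. A small sketch using the public API (the artifact path is a placeholder):
```python
import wandb

api = wandb.Api()
artifact = api.artifact("my-entity/my-project/trained-model:v0")  # placeholder path

producer = artifact.logged_by()   # run that produced this artifact
consumers = artifact.used_by()    # runs that consumed it downstream

print(f"Produced by: {producer.name if producer else 'unknown'}")
print(f"Used by: {[run.name for run in consumers]}")
```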
## Artifact Types
### Dataset Artifacts
```python
# Raw data
raw_data = wandb.Artifact('raw-data', type='dataset')
raw_data.add_dir('raw/')
# Processed data
processed_data = wandb.Artifact('processed-data', type='dataset')
processed_data.add_dir('processed/')
# Train/val/test splits
train_split = wandb.Artifact('train-split', type='dataset')
train_split.add_file('train.csv')
val_split = wandb.Artifact('val-split', type='dataset')
val_split.add_file('val.csv')
```
### Model Artifacts
```python
# Checkpoint during training
checkpoint = wandb.Artifact('checkpoint-epoch-10', type='model')
checkpoint.add_file('checkpoint_epoch_10.pth')
# Final model
final_model = wandb.Artifact('final-model', type='model')
final_model.add_file('model.pth')
final_model.add_file('tokenizer.json')
# Quantized model
quantized = wandb.Artifact('quantized-model', type='model')
quantized.add_file('model_int8.onnx')
```
### Result Artifacts
```python
# Predictions
predictions = wandb.Artifact('test-predictions', type='predictions')
predictions.add_file('predictions.csv')
# Evaluation metrics
eval_results = wandb.Artifact('evaluation', type='evaluation')
eval_results.add_file('metrics.json')
eval_results.add_file('confusion_matrix.png')
```
## Advanced Patterns
### Incremental Artifacts
Add files incrementally without re-uploading.
```python
run = wandb.init(project="my-project")
# Create artifact
dataset = wandb.Artifact('incremental-dataset', type='dataset')
# Add files incrementally
for i in range(100):
filename = f'batch_{i}.csv'
process_batch(i, filename)
dataset.add_file(filename)
# Log progress
if (i + 1) % 10 == 0:
print(f"Added {i + 1}/100 batches")
# Log complete artifact
run.log_artifact(dataset)
```
### Artifact Tables
Track structured data with W&B Tables.
```python
import wandb
run = wandb.init(project="my-project")
# Create table
table = wandb.Table(columns=["id", "image", "label", "prediction"])
for idx, (img, label, pred) in enumerate(zip(images, labels, predictions)):
table.add_data(
idx,
wandb.Image(img),
label,
pred
)
# Log as artifact
artifact = wandb.Artifact('predictions-table', type='predictions')
artifact.add(table, "predictions")
run.log_artifact(artifact)
```
### Artifact References
Reference external data without copying.
```python
# S3 reference
dataset = wandb.Artifact('s3-dataset', type='dataset')
dataset.add_reference('s3://my-bucket/data/', name='train')
dataset.add_reference('s3://my-bucket/labels/', name='labels')
# GCS reference
dataset.add_reference('gs://my-bucket/data/')
# HTTP reference
dataset.add_reference('https://example.com/data.zip')
# Local filesystem reference (for shared storage)
dataset.add_reference('file:///mnt/shared/data')
```
## Collaboration Patterns
### Team Dataset Sharing
```python
# Data engineer creates dataset
run = wandb.init(project="data-eng", entity="my-team")
dataset = wandb.Artifact('shared-dataset', type='dataset')
dataset.add_dir('data/')
run.log_artifact(dataset, aliases=['latest', 'production'])
# ML engineer uses dataset
run = wandb.init(project="ml-training", entity="my-team")
dataset = run.use_artifact('my-team/data-eng/shared-dataset:production')
data = load_data(dataset.download())
```
### Model Handoff
```python
# Training team
train_run = wandb.init(project="model-training", entity="ml-team")
model = train_model()
model_artifact = wandb.Artifact('nlp-model', type='model')
model_artifact.add_file('model.pth')
train_run.log_artifact(model_artifact)
train_run.link_artifact(model_artifact, 'model-registry/nlp-models', aliases=['candidate'])
# Evaluation team
eval_run = wandb.init(project="model-eval", entity="ml-team")
model_artifact = eval_run.use_artifact('model-registry/nlp-models/nlp-model:candidate')
metrics = evaluate_model(model_artifact)
if metrics['f1'] > 0.9:
# Promote to production
model_artifact.link('model-registry/nlp-models', aliases=['production'])
```
## Best Practices
### 1. Use Descriptive Names
```python
# ✅ Good: Descriptive names
wandb.Artifact('imagenet-train-augmented-v2', type='dataset')
wandb.Artifact('bert-base-sentiment-finetuned', type='model')
# ❌ Bad: Generic names
wandb.Artifact('dataset1', type='dataset')
wandb.Artifact('model', type='model')
```
### 2. Add Comprehensive Metadata
```python
model_artifact = wandb.Artifact(
'production-model',
type='model',
description='ResNet50 classifier for product categorization',
metadata={
# Model info
'architecture': 'ResNet50',
'framework': 'PyTorch 2.0',
'pretrained': True,
# Performance
'accuracy': 0.95,
'f1_score': 0.93,
'inference_time_ms': 15,
# Training
'epochs': 50,
'dataset': 'imagenet',
'num_samples': 1200000,
# Business context
'use_case': 'e-commerce product classification',
'owner': 'ml-team@company.com',
'approved_by': 'data-science-lead'
}
)
```
### 3. Use Aliases for Deployment Stages
```python
# Development
run.log_artifact(model, aliases=['dev', 'latest'])
# Staging
run.log_artifact(model, aliases=['staging'])
# Production
run.log_artifact(model, aliases=['production', 'v1.2.0'])
# Archive old versions
old_artifact = api.artifact('model:production')
old_artifact.aliases = ['archived-v1.1.0']
old_artifact.save()
```
### 4. Track Data Lineage
```python
def create_training_pipeline():
run = wandb.init(project="pipeline")
# 1. Load raw data
raw_data = run.use_artifact('raw-data:latest')
# 2. Preprocess
processed = preprocess(raw_data)
processed_artifact = wandb.Artifact('processed-data', type='dataset')
processed_artifact.add_file('processed.csv')
run.log_artifact(processed_artifact)
# 3. Train model
model = train(processed)
model_artifact = wandb.Artifact('trained-model', type='model')
model_artifact.add_file('model.pth')
run.log_artifact(model_artifact)
# Lineage: raw-data → processed-data → trained-model
```
### 5. Efficient Storage
```python
# ✅ Good: Reference large files
large_dataset = wandb.Artifact('large-dataset', type='dataset')
large_dataset.add_reference('s3://bucket/huge-file.tar.gz')
# ❌ Bad: Upload giant files
# large_dataset.add_file('huge-file.tar.gz') # Don't do this
# ✅ Good: Upload only metadata
metadata_artifact = wandb.Artifact('dataset-metadata', type='dataset')
metadata_artifact.add_file('metadata.json') # Small file
```
## Resources
- **Artifacts Documentation**: https://docs.wandb.ai/guides/artifacts
- **Model Registry**: https://docs.wandb.ai/guides/model-registry
- **Best Practices**: https://wandb.ai/site/articles/versioning-data-and-models-in-ml

View file

@ -0,0 +1,700 @@
# Framework Integrations Guide
Complete guide to integrating W&B with popular ML frameworks.
## Table of Contents
- HuggingFace Transformers
- PyTorch Lightning
- Keras/TensorFlow
- Fast.ai
- XGBoost/LightGBM
- PyTorch Native
- Custom Integrations
## HuggingFace Transformers
### Automatic Integration
```python
from transformers import Trainer, TrainingArguments
import wandb
# Initialize W&B
wandb.init(project="hf-transformers", name="bert-finetuning")
# Training arguments with W&B
training_args = TrainingArguments(
output_dir="./results",
report_to="wandb", # Enable W&B logging
run_name="bert-base-finetuning",
# Training params
num_train_epochs=3,
per_device_train_batch_size=16,
per_device_eval_batch_size=64,
learning_rate=2e-5,
# Logging
logging_dir="./logs",
logging_steps=100,
logging_first_step=True,
# Evaluation
evaluation_strategy="steps",
eval_steps=500,
save_steps=500,
# Other
load_best_model_at_end=True,
metric_for_best_model="eval_accuracy"
)
# Trainer automatically logs to W&B
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
compute_metrics=compute_metrics
)
# Train (metrics logged automatically)
trainer.train()
# Finish W&B run
wandb.finish()
```
### Custom Logging
```python
from transformers import Trainer, TrainingArguments
from transformers.integrations import WandbCallback
import wandb
class CustomWandbCallback(WandbCallback):
def on_evaluate(self, args, state, control, metrics=None, **kwargs):
super().on_evaluate(args, state, control, metrics, **kwargs)
# Log custom metrics
wandb.log({
"custom/eval_score": metrics["eval_accuracy"] * 100,
"custom/epoch": state.epoch
})
# Use custom callback
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
callbacks=[CustomWandbCallback()]
)
```
### Log Model to Registry
```python
from transformers import Trainer, TrainingArguments
training_args = TrainingArguments(
output_dir="./results",
report_to="wandb",
load_best_model_at_end=True
)
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset
)
trainer.train()
# Save final model as artifact
model_artifact = wandb.Artifact(
'hf-bert-model',
type='model',
description='BERT finetuned on sentiment analysis'
)
# Save model files
trainer.save_model("./final_model")
model_artifact.add_dir("./final_model")
# Log artifact
wandb.log_artifact(model_artifact, aliases=['best', 'production'])
wandb.finish()
```
## PyTorch Lightning
### Basic Integration
```python
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
import wandb
# Create W&B logger
wandb_logger = WandbLogger(
project="lightning-demo",
name="resnet50-training",
log_model=True, # Log model checkpoints as artifacts
save_code=True # Save code as artifact
)
# Lightning module
class LitModel(pl.LightningModule):
def __init__(self, learning_rate=0.001):
super().__init__()
self.save_hyperparameters()
self.model = create_model()
def training_step(self, batch, batch_idx):
x, y = batch
y_hat = self.model(x)
loss = F.cross_entropy(y_hat, y)
# Log metrics (automatically sent to W&B)
self.log('train/loss', loss, on_step=True, on_epoch=True)
self.log('train/accuracy', accuracy(y_hat, y), on_epoch=True)
return loss
def validation_step(self, batch, batch_idx):
x, y = batch
y_hat = self.model(x)
loss = F.cross_entropy(y_hat, y)
self.log('val/loss', loss, on_step=False, on_epoch=True)
self.log('val/accuracy', accuracy(y_hat, y), on_epoch=True)
return loss
def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
# Trainer with W&B logger
trainer = pl.Trainer(
logger=wandb_logger,
max_epochs=10,
accelerator="gpu",
devices=1
)
# Train (metrics logged automatically)
trainer.fit(model, datamodule=dm)
# Finish W&B run
wandb.finish()
```
### Log Media
```python
class LitModel(pl.LightningModule):
def validation_step(self, batch, batch_idx):
x, y = batch
y_hat = self.model(x)
# Log images (first batch only)
if batch_idx == 0:
self.logger.experiment.log({
"examples": [wandb.Image(img) for img in x[:8]]
})
return loss
def on_validation_epoch_end(self):
# Log confusion matrix
cm = compute_confusion_matrix(self.all_preds, self.all_targets)
self.logger.experiment.log({
"confusion_matrix": wandb.plot.confusion_matrix(
probs=None,
y_true=self.all_targets,
preds=self.all_preds,
class_names=self.class_names
)
})
```
### Hyperparameter Sweeps
```python
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
import wandb
# Define sweep
sweep_config = {
'method': 'bayes',
'metric': {'name': 'val/accuracy', 'goal': 'maximize'},
'parameters': {
        'learning_rate': {'min': 1e-5, 'max': 1e-2, 'distribution': 'log_uniform_values'},
'batch_size': {'values': [16, 32, 64]},
'hidden_size': {'values': [128, 256, 512]}
}
}
sweep_id = wandb.sweep(sweep_config, project="lightning-sweeps")
def train():
# Initialize W&B
run = wandb.init()
# Get hyperparameters
config = wandb.config
# Create logger
wandb_logger = WandbLogger()
# Create model with sweep params
model = LitModel(
learning_rate=config.learning_rate,
hidden_size=config.hidden_size
)
# Create datamodule with sweep batch size
dm = DataModule(batch_size=config.batch_size)
# Train
trainer = pl.Trainer(logger=wandb_logger, max_epochs=10)
trainer.fit(model, dm)
# Run sweep
wandb.agent(sweep_id, function=train, count=30)
```
## Keras/TensorFlow
### With Callback
```python
import tensorflow as tf
from wandb.keras import WandbCallback
import wandb
# Initialize W&B
wandb.init(
project="keras-demo",
config={
"learning_rate": 0.001,
"epochs": 10,
"batch_size": 32
}
)
config = wandb.config
# Build model
model = tf.keras.Sequential([
tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dropout(0.2),
tf.keras.layers.Dense(10, activation='softmax')
])
model.compile(
optimizer=tf.keras.optimizers.Adam(config.learning_rate),
loss='sparse_categorical_crossentropy',
metrics=['accuracy']
)
# Train with W&B callback
history = model.fit(
x_train, y_train,
validation_data=(x_val, y_val),
epochs=config.epochs,
batch_size=config.batch_size,
callbacks=[
WandbCallback(
log_weights=True, # Log model weights
log_gradients=True, # Log gradients
training_data=(x_train, y_train),
validation_data=(x_val, y_val),
labels=class_names
)
]
)
# Save model as artifact
model.save('model.h5')
artifact = wandb.Artifact('keras-model', type='model')
artifact.add_file('model.h5')
wandb.log_artifact(artifact)
wandb.finish()
```
### Custom Training Loop
```python
import tensorflow as tf
import wandb
wandb.init(project="tf-custom-loop")
# Model, optimizer, loss
model = create_model()
optimizer = tf.keras.optimizers.Adam(1e-3)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy()
# Metrics
train_loss = tf.keras.metrics.Mean(name='train_loss')
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')
@tf.function
def train_step(x, y):
with tf.GradientTape() as tape:
predictions = model(x, training=True)
loss = loss_fn(y, predictions)
gradients = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(gradients, model.trainable_variables))
train_loss(loss)
train_accuracy(y, predictions)
# Training loop
for epoch in range(EPOCHS):
train_loss.reset_states()
train_accuracy.reset_states()
for step, (x, y) in enumerate(train_dataset):
train_step(x, y)
# Log every 100 steps
if step % 100 == 0:
wandb.log({
'train/loss': train_loss.result().numpy(),
'train/accuracy': train_accuracy.result().numpy(),
'epoch': epoch,
'step': step
})
# Log epoch metrics
wandb.log({
'epoch/train_loss': train_loss.result().numpy(),
'epoch/train_accuracy': train_accuracy.result().numpy(),
'epoch': epoch
})
wandb.finish()
```
## Fast.ai
### With Callback
```python
from fastai.vision.all import *
from fastai.callback.wandb import *
import wandb
# Initialize W&B
wandb.init(project="fastai-demo")
# Create data loaders
dls = ImageDataLoaders.from_folder(
path,
train='train',
valid='valid',
bs=64
)
# Create learner with W&B callback
learn = vision_learner(
dls,
resnet34,
metrics=accuracy,
cbs=WandbCallback(
log_preds=True, # Log predictions
log_model=True, # Log model as artifact
log_dataset=True # Log dataset as artifact
)
)
# Train (metrics logged automatically)
learn.fine_tune(5)
wandb.finish()
```
## XGBoost/LightGBM
### XGBoost
```python
import xgboost as xgb
import wandb
# Initialize W&B
run = wandb.init(project="xgboost-demo", config={
"max_depth": 6,
"learning_rate": 0.1,
"n_estimators": 100
})
config = wandb.config
# Create DMatrix
dtrain = xgb.DMatrix(X_train, label=y_train)
dval = xgb.DMatrix(X_val, label=y_val)
# XGBoost params
params = {
'max_depth': config.max_depth,
'learning_rate': config.learning_rate,
'objective': 'binary:logistic',
'eval_metric': ['logloss', 'auc']
}
# Custom callback for W&B
def wandb_callback(env):
"""Log XGBoost metrics to W&B."""
for metric_name, metric_value in env.evaluation_result_list:
wandb.log({
f"{metric_name}": metric_value,
"iteration": env.iteration
})
# Train with callback
model = xgb.train(
params,
dtrain,
num_boost_round=config.n_estimators,
evals=[(dtrain, 'train'), (dval, 'val')],
callbacks=[wandb_callback],
verbose_eval=10
)
# Save model
model.save_model('xgboost_model.json')
artifact = wandb.Artifact('xgboost-model', type='model')
artifact.add_file('xgboost_model.json')
wandb.log_artifact(artifact)
wandb.finish()
```
### LightGBM
```python
import lightgbm as lgb
import wandb
run = wandb.init(project="lgbm-demo")
# Create datasets
train_data = lgb.Dataset(X_train, label=y_train)
val_data = lgb.Dataset(X_val, label=y_val, reference=train_data)
# Parameters
params = {
'objective': 'binary',
'metric': ['binary_logloss', 'auc'],
'learning_rate': 0.1,
'num_leaves': 31
}
# Custom callback
def log_to_wandb(env):
"""Log LightGBM metrics to W&B."""
for entry in env.evaluation_result_list:
dataset_name, metric_name, metric_value, _ = entry
wandb.log({
f"{dataset_name}/{metric_name}": metric_value,
"iteration": env.iteration
})
# Train
model = lgb.train(
params,
train_data,
num_boost_round=100,
valid_sets=[train_data, val_data],
valid_names=['train', 'val'],
callbacks=[log_to_wandb]
)
# Save model
model.save_model('lgbm_model.txt')
artifact = wandb.Artifact('lgbm-model', type='model')
artifact.add_file('lgbm_model.txt')
wandb.log_artifact(artifact)
wandb.finish()
```
## PyTorch Native
### Training Loop Integration
```python
import torch
import torch.nn as nn
import torch.optim as optim
import wandb
# Initialize W&B
wandb.init(project="pytorch-native", config={
"learning_rate": 0.001,
"epochs": 10,
"batch_size": 32
})
config = wandb.config
# Model, loss, optimizer
model = create_model()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)
# Watch model (logs gradients and parameters)
wandb.watch(model, criterion, log="all", log_freq=100)
# Training loop
for epoch in range(config.epochs):
model.train()
train_loss = 0.0
correct = 0
total = 0
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
# Forward pass
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
# Backward pass
loss.backward()
optimizer.step()
# Track metrics
train_loss += loss.item()
_, predicted = output.max(1)
total += target.size(0)
correct += predicted.eq(target).sum().item()
# Log every 100 batches
if batch_idx % 100 == 0:
wandb.log({
'train/loss': loss.item(),
'train/batch_accuracy': 100. * correct / total,
'epoch': epoch,
'batch': batch_idx
})
# Validation
model.eval()
val_loss = 0.0
val_correct = 0
val_total = 0
with torch.no_grad():
for data, target in val_loader:
data, target = data.to(device), target.to(device)
output = model(data)
loss = criterion(output, target)
val_loss += loss.item()
_, predicted = output.max(1)
val_total += target.size(0)
val_correct += predicted.eq(target).sum().item()
# Log epoch metrics
wandb.log({
'epoch/train_loss': train_loss / len(train_loader),
'epoch/train_accuracy': 100. * correct / total,
'epoch/val_loss': val_loss / len(val_loader),
'epoch/val_accuracy': 100. * val_correct / val_total,
'epoch': epoch
})
# Save final model
torch.save(model.state_dict(), 'model.pth')
artifact = wandb.Artifact('final-model', type='model')
artifact.add_file('model.pth')
wandb.log_artifact(artifact)
wandb.finish()
```
## Custom Integrations
### Generic Framework Integration
```python
import wandb
class WandbIntegration:
"""Generic W&B integration wrapper."""
def __init__(self, project, config):
self.run = wandb.init(project=project, config=config)
self.config = wandb.config
self.step = 0
def log_metrics(self, metrics, step=None):
"""Log training metrics."""
if step is None:
step = self.step
self.step += 1
wandb.log(metrics, step=step)
def log_images(self, images, caption=""):
"""Log images."""
wandb.log({
caption: [wandb.Image(img) for img in images]
})
def log_table(self, data, columns):
"""Log tabular data."""
table = wandb.Table(columns=columns, data=data)
wandb.log({"table": table})
def save_model(self, model_path, metadata=None):
"""Save model as artifact."""
artifact = wandb.Artifact(
'model',
type='model',
metadata=metadata or {}
)
artifact.add_file(model_path)
self.run.log_artifact(artifact)
def finish(self):
"""Finish W&B run."""
wandb.finish()
# Usage
wb = WandbIntegration(project="my-project", config={"lr": 0.001})
# Training loop
for epoch in range(10):
# Your training code
loss, accuracy = train_epoch()
# Log metrics
wb.log_metrics({
'train/loss': loss,
'train/accuracy': accuracy
})
# Save model
wb.save_model('model.pth', metadata={'accuracy': 0.95})
wb.finish()
```
## Resources
- **Integrations Guide**: https://docs.wandb.ai/guides/integrations
- **HuggingFace**: https://docs.wandb.ai/guides/integrations/huggingface
- **PyTorch Lightning**: https://docs.wandb.ai/guides/integrations/lightning
- **Keras**: https://docs.wandb.ai/guides/integrations/keras
- **Examples**: https://github.com/wandb/examples

View file

@ -0,0 +1,847 @@
# Comprehensive Hyperparameter Sweeps Guide
Complete guide to hyperparameter optimization with W&B Sweeps.
## Table of Contents
- Sweep Configuration
- Search Strategies
- Parameter Distributions
- Early Termination
- Parallel Execution
- Advanced Patterns
- Real-World Examples
## Sweep Configuration
### Basic Sweep Config
```python
sweep_config = {
'method': 'bayes', # Search strategy
'metric': {
'name': 'val/accuracy',
'goal': 'maximize' # or 'minimize'
},
'parameters': {
'learning_rate': {
            'distribution': 'log_uniform_values',
'min': 1e-5,
'max': 1e-1
},
'batch_size': {
'values': [16, 32, 64, 128]
}
}
}
# Initialize sweep
sweep_id = wandb.sweep(sweep_config, project="my-project")
```
### Complete Config Example
```python
sweep_config = {
# Required: Search method
'method': 'bayes',
# Required: Optimization metric
'metric': {
'name': 'val/f1_score',
'goal': 'maximize'
},
# Required: Parameters to search
'parameters': {
# Continuous parameter
'learning_rate': {
            'distribution': 'log_uniform_values',
'min': 1e-5,
'max': 1e-1
},
# Discrete values
'batch_size': {
'values': [16, 32, 64, 128]
},
# Categorical
'optimizer': {
'values': ['adam', 'sgd', 'rmsprop', 'adamw']
},
# Uniform distribution
'dropout': {
'distribution': 'uniform',
'min': 0.1,
'max': 0.5
},
# Integer range
'num_layers': {
'distribution': 'int_uniform',
'min': 2,
'max': 10
},
# Fixed value (constant across runs)
'epochs': {
'value': 50
}
},
# Optional: Early termination
'early_terminate': {
'type': 'hyperband',
'min_iter': 5,
's': 2,
'eta': 3,
'max_iter': 27
}
}
```
## Search Strategies
### 1. Grid Search
Exhaustively search all combinations.
```python
sweep_config = {
'method': 'grid',
'parameters': {
'learning_rate': {
'values': [0.001, 0.01, 0.1]
},
'batch_size': {
'values': [16, 32, 64]
},
'optimizer': {
'values': ['adam', 'sgd']
}
}
}
# Total runs: 3 × 3 × 2 = 18 runs
```
**Pros:**
- Comprehensive search
- Reproducible results
- No randomness
**Cons:**
- Exponential growth with parameters
- Inefficient for continuous parameters
- Not scalable beyond 3-4 parameters
**When to use:**
- Few parameters (< 4)
- All discrete values
- Need complete coverage
### 2. Random Search
Randomly sample parameter combinations.
```python
sweep_config = {
'method': 'random',
'parameters': {
'learning_rate': {
            'distribution': 'log_uniform_values',
'min': 1e-5,
'max': 1e-1
},
'batch_size': {
'values': [16, 32, 64, 128, 256]
},
'dropout': {
'distribution': 'uniform',
'min': 0.0,
'max': 0.5
},
'num_layers': {
'distribution': 'int_uniform',
'min': 2,
'max': 8
}
}
}
# Run 100 random trials
wandb.agent(sweep_id, function=train, count=100)
```
**Pros:**
- Scales to many parameters
- Can run indefinitely
- Often finds good solutions quickly
**Cons:**
- No learning from previous runs
- May miss optimal region
- Results vary with random seed
**When to use:**
- Many parameters (> 4)
- Quick exploration
- Limited budget
### 3. Bayesian Optimization (Recommended)
Learn from previous trials to sample promising regions.
```python
sweep_config = {
'method': 'bayes',
'metric': {
'name': 'val/loss',
'goal': 'minimize'
},
'parameters': {
'learning_rate': {
            'distribution': 'log_uniform_values',
'min': 1e-5,
'max': 1e-1
},
'weight_decay': {
            'distribution': 'log_uniform_values',
'min': 1e-6,
'max': 1e-2
},
'dropout': {
'distribution': 'uniform',
'min': 0.1,
'max': 0.5
},
'num_layers': {
'values': [2, 3, 4, 5, 6]
}
}
}
```
**Pros:**
- Most sample-efficient
- Learns from past trials
- Focuses on promising regions
**Cons:**
- Initial random exploration phase
- May get stuck in local optima
- Slower per iteration
**When to use:**
- Expensive training runs
- Need best performance
- Limited compute budget
## Parameter Distributions
### Continuous Distributions
```python
# Log-uniform over raw values: good for learning rates, regularization
# (plain 'log_uniform' expects min/max as natural-log exponents instead)
'learning_rate': {
    'distribution': 'log_uniform_values',
'min': 1e-6,
'max': 1e-1
}
# Uniform: Good for dropout, momentum
'dropout': {
'distribution': 'uniform',
'min': 0.0,
'max': 0.5
}
# Normal distribution
'parameter': {
'distribution': 'normal',
'mu': 0.5,
'sigma': 0.1
}
# Log-normal distribution
'parameter': {
'distribution': 'log_normal',
'mu': 0.0,
'sigma': 1.0
}
```
### Discrete Distributions
```python
# Fixed values
'batch_size': {
'values': [16, 32, 64, 128, 256]
}
# Integer uniform
'num_layers': {
'distribution': 'int_uniform',
'min': 2,
'max': 10
}
# Quantized uniform (step size)
'layer_size': {
'distribution': 'q_uniform',
'min': 32,
'max': 512,
'q': 32 # Step by 32: 32, 64, 96, 128...
}
# Quantized log-uniform
'hidden_size': {
    'distribution': 'q_log_uniform_values',
'min': 32,
'max': 1024,
'q': 32
}
```
### Categorical Parameters
```python
# Optimizers
'optimizer': {
'values': ['adam', 'sgd', 'rmsprop', 'adamw']
}
# Model architectures
'model': {
'values': ['resnet18', 'resnet34', 'resnet50', 'efficientnet_b0']
}
# Activation functions
'activation': {
'values': ['relu', 'gelu', 'silu', 'leaky_relu']
}
```
## Early Termination
Stop underperforming runs early to save compute.
### Hyperband
```python
sweep_config = {
'method': 'bayes',
'metric': {'name': 'val/accuracy', 'goal': 'maximize'},
'parameters': {...},
# Hyperband early termination
'early_terminate': {
'type': 'hyperband',
'min_iter': 3, # Minimum iterations before termination
's': 2, # Bracket count
'eta': 3, # Downsampling rate
'max_iter': 27 # Maximum iterations
}
}
```
**How it works:**
- Runs trials in brackets
- Keeps top 1/eta performers each round
- Eliminates bottom performers early
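As a rough sketch of the schedule this implies for `min_iter=3` and `eta=3` (W&B's exact bracket construction may differ), rungs fall at iterations 3, 9, and 27, and only about a third of runs survive each rung:
```python
# Rough illustration of Hyperband rungs; not W&B's internal implementation
min_iter, eta, max_iter = 3, 3, 27

rung, surviving_fraction = min_iter, 1.0
while rung <= max_iter:
    print(f"evaluate at iteration {rung:>2}: ~{surviving_fraction:.0%} of runs reach this rung")
    surviving_fraction /= eta
    rung *= eta
```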
### Custom Termination
```python
def train():
    run = wandb.init()
    best_acc = 0.0
    epochs_without_improvement = 0
    for epoch in range(MAX_EPOCHS):
        loss = train_epoch()
        val_acc = validate()
        wandb.log({'val/accuracy': val_acc, 'epoch': epoch})
        # Custom early stopping: clearly failing run
        if epoch > 5 and val_acc < 0.5:
            print("Early stop: Poor performance")
            break
        # Custom early stopping: no improvement for 10 epochs
        if val_acc > best_acc + 0.001:
            best_acc = val_acc
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
        if epochs_without_improvement >= 10:
            print("Early stop: No improvement")
            break
```
## Training Function
### Basic Template
```python
def train():
# Initialize W&B run
run = wandb.init()
# Get hyperparameters
config = wandb.config
# Build model with config
model = build_model(
hidden_size=config.hidden_size,
num_layers=config.num_layers,
dropout=config.dropout
)
# Create optimizer
optimizer = create_optimizer(
model.parameters(),
name=config.optimizer,
lr=config.learning_rate,
weight_decay=config.weight_decay
)
# Training loop
for epoch in range(config.epochs):
# Train
train_loss, train_acc = train_epoch(
model, optimizer, train_loader, config.batch_size
)
# Validate
val_loss, val_acc = validate(model, val_loader)
# Log metrics
wandb.log({
'train/loss': train_loss,
'train/accuracy': train_acc,
'val/loss': val_loss,
'val/accuracy': val_acc,
'epoch': epoch
})
# Log final model
torch.save(model.state_dict(), 'model.pth')
wandb.save('model.pth')
# Finish run
wandb.finish()
```
### With PyTorch
```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import wandb
def train():
run = wandb.init()
config = wandb.config
# Data
train_loader = DataLoader(
train_dataset,
batch_size=config.batch_size,
shuffle=True
)
# Model
model = ResNet(
num_classes=config.num_classes,
dropout=config.dropout
).to(device)
# Optimizer
if config.optimizer == 'adam':
optimizer = torch.optim.Adam(
model.parameters(),
lr=config.learning_rate,
weight_decay=config.weight_decay
)
elif config.optimizer == 'sgd':
optimizer = torch.optim.SGD(
model.parameters(),
lr=config.learning_rate,
momentum=config.momentum,
weight_decay=config.weight_decay
)
# Scheduler
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
optimizer, T_max=config.epochs
)
# Training
for epoch in range(config.epochs):
model.train()
train_loss = 0.0
for data, target in train_loader:
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = nn.CrossEntropyLoss()(output, target)
loss.backward()
optimizer.step()
train_loss += loss.item()
# Validation
model.eval()
val_loss, val_acc = validate(model, val_loader)
# Step scheduler
scheduler.step()
# Log
wandb.log({
'train/loss': train_loss / len(train_loader),
'val/loss': val_loss,
'val/accuracy': val_acc,
'learning_rate': scheduler.get_last_lr()[0],
'epoch': epoch
})
```
## Parallel Execution
### Multiple Agents
Run sweep agents in parallel to speed up search.
```python
# Initialize sweep once
sweep_id = wandb.sweep(sweep_config, project="my-project")
# Run multiple agents in parallel
# Agent 1 (Terminal 1)
wandb.agent(sweep_id, function=train, count=20)
# Agent 2 (Terminal 2)
wandb.agent(sweep_id, function=train, count=20)
# Agent 3 (Terminal 3)
wandb.agent(sweep_id, function=train, count=20)
# Total: 60 runs across 3 agents
```
### Multi-GPU Execution
```python
import os
def train():
    run = wandb.init()
    config = wandb.config
    # CUDA_VISIBLE_DEVICES (set per agent below) exposes the chosen GPU as cuda:0 in this process
    device = torch.device('cuda')
model = model.to(device)
# ... rest of training ...
# Run agents on different GPUs
# Terminal 1
# CUDA_VISIBLE_DEVICES=0 wandb agent sweep_id
# Terminal 2
# CUDA_VISIBLE_DEVICES=1 wandb agent sweep_id
# Terminal 3
# CUDA_VISIBLE_DEVICES=2 wandb agent sweep_id
```
## Advanced Patterns
### Nested Parameters
```python
sweep_config = {
'method': 'bayes',
'metric': {'name': 'val/accuracy', 'goal': 'maximize'},
'parameters': {
'model': {
'parameters': {
'type': {
'values': ['resnet', 'efficientnet']
},
'size': {
'values': ['small', 'medium', 'large']
}
}
},
'optimizer': {
'parameters': {
'type': {
'values': ['adam', 'sgd']
},
'lr': {
                    'distribution': 'log_uniform_values',
'min': 1e-5,
'max': 1e-1
}
}
}
}
}
# Access nested config (nested parameter groups come back as plain dicts)
def train():
    run = wandb.init()
    model_type = wandb.config.model["type"]
    model_size = wandb.config.model["size"]
    opt_type = wandb.config.optimizer["type"]
    lr = wandb.config.optimizer["lr"]
```
### Conditional Parameters
```python
sweep_config = {
'method': 'bayes',
'parameters': {
'optimizer': {
'values': ['adam', 'sgd']
},
'learning_rate': {
            'distribution': 'log_uniform_values',
'min': 1e-5,
'max': 1e-1
},
# Only used if optimizer == 'sgd'
'momentum': {
'distribution': 'uniform',
'min': 0.5,
'max': 0.99
}
}
}
def train():
run = wandb.init()
config = wandb.config
if config.optimizer == 'adam':
optimizer = torch.optim.Adam(
model.parameters(),
lr=config.learning_rate
)
elif config.optimizer == 'sgd':
optimizer = torch.optim.SGD(
model.parameters(),
lr=config.learning_rate,
momentum=config.momentum # Conditional parameter
)
```
## Real-World Examples
### Image Classification
```python
sweep_config = {
'method': 'bayes',
'metric': {
'name': 'val/top1_accuracy',
'goal': 'maximize'
},
'parameters': {
# Model
'architecture': {
'values': ['resnet50', 'resnet101', 'efficientnet_b0', 'efficientnet_b3']
},
'pretrained': {
'values': [True, False]
},
# Training
'learning_rate': {
            'distribution': 'log_uniform_values',
'min': 1e-5,
'max': 1e-2
},
'batch_size': {
'values': [16, 32, 64, 128]
},
'optimizer': {
'values': ['adam', 'sgd', 'adamw']
},
'weight_decay': {
            'distribution': 'log_uniform_values',
'min': 1e-6,
'max': 1e-2
},
# Regularization
'dropout': {
'distribution': 'uniform',
'min': 0.0,
'max': 0.5
},
'label_smoothing': {
'distribution': 'uniform',
'min': 0.0,
'max': 0.2
},
# Data augmentation
'mixup_alpha': {
'distribution': 'uniform',
'min': 0.0,
'max': 1.0
},
'cutmix_alpha': {
'distribution': 'uniform',
'min': 0.0,
'max': 1.0
}
},
'early_terminate': {
'type': 'hyperband',
'min_iter': 5
}
}
```
### NLP Fine-Tuning
```python
sweep_config = {
'method': 'bayes',
'metric': {'name': 'eval/f1', 'goal': 'maximize'},
'parameters': {
# Model
'model_name': {
'values': ['bert-base-uncased', 'roberta-base', 'distilbert-base-uncased']
},
# Training
'learning_rate': {
            'distribution': 'log_uniform_values',
'min': 1e-6,
'max': 1e-4
},
'per_device_train_batch_size': {
'values': [8, 16, 32]
},
'num_train_epochs': {
'values': [3, 4, 5]
},
'warmup_ratio': {
'distribution': 'uniform',
'min': 0.0,
'max': 0.1
},
'weight_decay': {
            'distribution': 'log_uniform_values',
'min': 1e-4,
'max': 1e-1
},
# Optimizer
'adam_beta1': {
'distribution': 'uniform',
'min': 0.8,
'max': 0.95
},
'adam_beta2': {
'distribution': 'uniform',
'min': 0.95,
'max': 0.999
}
}
}
```
## Best Practices
### 1. Start Small
```python
# Initial exploration: Random search, 20 runs
sweep_config_v1 = {
'method': 'random',
'parameters': {...}
}
sweep_id_v1 = wandb.sweep(sweep_config_v1, project="my-project")
wandb.agent(sweep_id_v1, train, count=20)
# Refined search: Bayes, narrow ranges
sweep_config_v2 = {
'method': 'bayes',
'parameters': {
'learning_rate': {
'min': 5e-5, # Narrowed from 1e-6 to 1e-4
'max': 1e-4
}
}
}
```
### 2. Use Log Scales
```python
# ✅ Good: Log scale for learning rate
'learning_rate': {
    'distribution': 'log_uniform_values',
'min': 1e-6,
'max': 1e-2
}
# ❌ Bad: Linear scale
'learning_rate': {
'distribution': 'uniform',
'min': 0.000001,
'max': 0.01
}
```
### 3. Set Reasonable Ranges
```python
# Base ranges on prior knowledge
'learning_rate': {'min': 1e-5, 'max': 1e-3}, # Typical for Adam
'batch_size': {'values': [16, 32, 64]}, # GPU memory limits
'dropout': {'min': 0.1, 'max': 0.5} # Too high hurts training
```
### 4. Monitor Resource Usage
```python
def train():
run = wandb.init()
# Log system metrics
wandb.log({
'system/gpu_memory_allocated': torch.cuda.memory_allocated(),
'system/gpu_memory_reserved': torch.cuda.memory_reserved()
})
```
### 5. Save Best Models
```python
def train():
    run = wandb.init()
    config = wandb.config
    best_acc = 0.0
for epoch in range(config.epochs):
val_acc = validate(model)
if val_acc > best_acc:
best_acc = val_acc
# Save best checkpoint
torch.save(model.state_dict(), 'best_model.pth')
wandb.save('best_model.pth')
```
## Resources
- **Sweeps Documentation**: https://docs.wandb.ai/guides/sweeps
- **Configuration Reference**: https://docs.wandb.ai/guides/sweeps/configuration
- **Examples**: https://github.com/wandb/examples/tree/master/examples/wandb-sweeps

View file

@ -0,0 +1,80 @@
---
name: huggingface-hub
description: Hugging Face Hub CLI (hf) — search, download, and upload models and datasets, manage repos, query datasets with SQL, deploy inference endpoints, manage Spaces and buckets.
version: 1.0.0
author: Hugging Face
license: MIT
tags: [huggingface, hf, models, datasets, hub, mlops]
---
# Hugging Face CLI (`hf`) Reference Guide
The `hf` command is the modern command-line interface for interacting with the Hugging Face Hub, providing tools to manage repositories, models, datasets, and Spaces.
> **IMPORTANT:** The `hf` command replaces the now deprecated `huggingface-cli` command.
## Quick Start
* **Installation:** `curl -LsSf https://hf.co/cli/install.sh | bash -s`
* **Help:** Use `hf --help` to view all available functions and real-world examples.
* **Authentication:** Recommended via `HF_TOKEN` environment variable or the `--token` flag.
---
## Core Commands
### General Operations
* `hf download REPO_ID`: Download files from the Hub.
* `hf upload REPO_ID`: Upload files/folders as a single commit (see the example below).
* `hf upload-large-folder REPO_ID LOCAL_PATH`: Recommended for resumable uploads of large directories.
* `hf sync`: Sync files between a local directory and a bucket.
* `hf env` / `hf version`: View environment and version details.
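A minimal sketch combining `hf download` and `hf upload` from the list above (the repository IDs and local paths are illustrative):

```bash
# Download a model repository into a local directory
hf download meta-llama/Llama-3.1-8B-Instruct --local-dir ./llama-3.1-8b
# Upload a local checkpoint folder to your own repo as a single commit
hf upload my-org/my-model ./checkpoints
```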
### Authentication (`hf auth`)
* `login` / `logout`: Manage sessions using tokens from [huggingface.co/settings/tokens](https://huggingface.co/settings/tokens).
* `list` / `switch`: Manage and toggle between multiple stored access tokens.
* `whoami`: Identify the currently logged-in account.
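For example, a typical non-interactive login (passing the token via `--token` is an assumption based on the `HF_TOKEN`/`--token` conventions noted above):

```bash
# Authenticate with a token from huggingface.co/settings/tokens, then verify
hf auth login --token "$HF_TOKEN"
hf auth whoami
```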
### Repository Management (`hf repos`)
* `create` / `delete`: Create or permanently remove repositories.
* `duplicate`: Clone a model, dataset, or Space to a new ID.
* `move`: Transfer a repository between namespaces.
* `branch` / `tag`: Manage Git-like references.
* `delete-files`: Remove specific files using patterns.
---
## Specialized Hub Interactions
### Datasets & Models
* **Datasets:** `hf datasets list`, `info`, and `parquet` (list parquet URLs).
* **SQL Queries:** `hf datasets sql SQL` — Execute raw SQL via DuckDB against dataset parquet URLs.
* **Models:** `hf models list` and `info`.
* **Papers:** `hf papers list` — View daily papers.
### Discussions & Pull Requests (`hf discussions`)
* Manage the lifecycle of Hub contributions: `list`, `create`, `info`, `comment`, `close`, `reopen`, and `rename`.
* `diff`: View changes in a PR.
* `merge`: Finalize pull requests.
### Infrastructure & Compute
* **Endpoints:** Deploy and manage Inference Endpoints (`deploy`, `pause`, `resume`, `scale-to-zero`, `catalog`).
* **Jobs:** Run compute tasks on HF infrastructure. Includes `hf jobs uv` for running Python scripts with inline dependencies and `stats` for resource monitoring.
* **Spaces:** Manage interactive apps. Includes `dev-mode` and `hot-reload` for Python files without full restarts.
### Storage & Automation
* **Buckets:** Full S3-like bucket management (`create`, `cp`, `mv`, `rm`, `sync`).
* **Cache:** Manage local storage with `list`, `prune` (remove detached revisions), and `verify` (checksum checks).
* **Webhooks:** Automate workflows by managing Hub webhooks (`create`, `watch`, `enable`/`disable`).
* **Collections:** Organize Hub items into collections (`add-item`, `update`, `list`).
---
## Advanced Usage & Tips
### Global Flags
* `--format json`: Produces machine-readable output for automation.
* `-q` / `--quiet`: Limits output to IDs only.
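A hedged example of using these flags for scripting (the JSON structure returned by the list commands is not documented here, so the `jq` usage is an assumption):

```bash
# JSON output for automation; -q limits output to IDs only
hf models list --format json | jq '.'
hf datasets list -q
```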
### Extensions & Skills
* **Extensions:** Extend CLI functionality via GitHub repositories using `hf extensions install REPO_ID`.
* **Skills:** Manage AI assistant skills with `hf skills add`.

View file

@ -0,0 +1,3 @@
---
description: Model serving, quantization (GGUF/GPTQ), structured output, inference optimization, and model surgery tools for deploying and running LLMs.
---

View file

@ -0,0 +1,430 @@
---
name: gguf-quantization
description: GGUF format and llama.cpp quantization for efficient CPU/GPU inference. Use when deploying models on consumer hardware, Apple Silicon, or when needing flexible quantization from 2-8 bit without GPU requirements.
version: 1.0.0
author: Orchestra Research
license: MIT
dependencies: [llama-cpp-python>=0.2.0]
metadata:
hermes:
tags: [GGUF, Quantization, llama.cpp, CPU Inference, Apple Silicon, Model Compression, Optimization]
---
# GGUF - Quantization Format for llama.cpp
The GGUF (GPT-Generated Unified Format) is the standard file format for llama.cpp, enabling efficient inference on CPUs, Apple Silicon, and GPUs with flexible quantization options.
## When to use GGUF
**Use GGUF when:**
- Deploying on consumer hardware (laptops, desktops)
- Running on Apple Silicon (M1/M2/M3) with Metal acceleration
- Need CPU inference without GPU requirements
- Want flexible quantization (Q2_K to Q8_0)
- Using local AI tools (LM Studio, Ollama, text-generation-webui)
**Key advantages:**
- **Universal hardware**: CPU, Apple Silicon, NVIDIA, AMD support
- **No Python runtime**: Pure C/C++ inference
- **Flexible quantization**: 2-8 bit with various methods (K-quants)
- **Ecosystem support**: LM Studio, Ollama, koboldcpp, and more
- **imatrix**: Importance matrix for better low-bit quality
**Use alternatives instead:**
- **AWQ/GPTQ**: Maximum accuracy with calibration on NVIDIA GPUs
- **HQQ**: Fast calibration-free quantization for HuggingFace
- **bitsandbytes**: Simple integration with transformers library
- **TensorRT-LLM**: Production NVIDIA deployment with maximum speed
## Quick start
### Installation
```bash
# Clone llama.cpp
git clone https://github.com/ggml-org/llama.cpp
cd llama.cpp
# Build (CPU)
make
# Build with CUDA (NVIDIA)
make GGML_CUDA=1
# Build with Metal (Apple Silicon)
make GGML_METAL=1
# Install Python bindings (optional)
pip install llama-cpp-python
```
### Convert model to GGUF
```bash
# Install requirements
pip install -r requirements.txt
# Convert HuggingFace model to GGUF (FP16)
python convert_hf_to_gguf.py ./path/to/model --outfile model-f16.gguf
# Or specify output type
python convert_hf_to_gguf.py ./path/to/model \
--outfile model-f16.gguf \
--outtype f16
```
### Quantize model
```bash
# Basic quantization to Q4_K_M
./llama-quantize model-f16.gguf model-q4_k_m.gguf Q4_K_M
# Quantize with importance matrix (better quality)
./llama-imatrix -m model-f16.gguf -f calibration.txt -o model.imatrix
./llama-quantize --imatrix model.imatrix model-f16.gguf model-q4_k_m.gguf Q4_K_M
```
### Run inference
```bash
# CLI inference
./llama-cli -m model-q4_k_m.gguf -p "Hello, how are you?"
# Interactive mode
./llama-cli -m model-q4_k_m.gguf --interactive
# With GPU offload
./llama-cli -m model-q4_k_m.gguf -ngl 35 -p "Hello!"
```
## Quantization types
### K-quant methods (recommended)
| Type | Bits | Size (7B) | Quality | Use Case |
|------|------|-----------|---------|----------|
| Q2_K | 2.5 | ~2.8 GB | Low | Extreme compression |
| Q3_K_S | 3.0 | ~3.0 GB | Low-Med | Memory constrained |
| Q3_K_M | 3.3 | ~3.3 GB | Medium | Balance |
| Q4_K_S | 4.0 | ~3.8 GB | Med-High | Good balance |
| Q4_K_M | 4.5 | ~4.1 GB | High | **Recommended default** |
| Q5_K_S | 5.0 | ~4.6 GB | High | Quality focused |
| Q5_K_M | 5.5 | ~4.8 GB | Very High | High quality |
| Q6_K | 6.0 | ~5.5 GB | Excellent | Near-original |
| Q8_0 | 8.0 | ~7.2 GB | Best | Maximum quality |
### Legacy methods
| Type | Description |
|------|-------------|
| Q4_0 | 4-bit, basic |
| Q4_1 | 4-bit with delta |
| Q5_0 | 5-bit, basic |
| Q5_1 | 5-bit with delta |
**Recommendation**: Use K-quant methods (Q4_K_M, Q5_K_M) for best quality/size ratio.
## Conversion workflows
### Workflow 1: HuggingFace to GGUF
```bash
# 1. Download model
huggingface-cli download meta-llama/Llama-3.1-8B --local-dir ./llama-3.1-8b
# 2. Convert to GGUF (FP16)
python convert_hf_to_gguf.py ./llama-3.1-8b \
--outfile llama-3.1-8b-f16.gguf \
--outtype f16
# 3. Quantize
./llama-quantize llama-3.1-8b-f16.gguf llama-3.1-8b-q4_k_m.gguf Q4_K_M
# 4. Test
./llama-cli -m llama-3.1-8b-q4_k_m.gguf -p "Hello!" -n 50
```
### Workflow 2: With importance matrix (better quality)
```bash
# 1. Convert to GGUF
python convert_hf_to_gguf.py ./model --outfile model-f16.gguf
# 2. Create calibration text (diverse samples)
cat > calibration.txt << 'EOF'
The quick brown fox jumps over the lazy dog.
Machine learning is a subset of artificial intelligence.
Python is a popular programming language.
# Add more diverse text samples...
EOF
# 3. Generate importance matrix
./llama-imatrix -m model-f16.gguf \
-f calibration.txt \
--chunk 512 \
-o model.imatrix \
-ngl 35 # GPU layers if available
# 4. Quantize with imatrix
./llama-quantize --imatrix model.imatrix \
model-f16.gguf \
model-q4_k_m.gguf \
Q4_K_M
```
### Workflow 3: Multiple quantizations
```bash
#!/bin/bash
MODEL="llama-3.1-8b-f16.gguf"
IMATRIX="llama-3.1-8b.imatrix"
# Generate imatrix once
./llama-imatrix -m $MODEL -f wiki.txt -o $IMATRIX -ngl 35
# Create multiple quantizations
for QUANT in Q4_K_M Q5_K_M Q6_K Q8_0; do
OUTPUT="llama-3.1-8b-${QUANT,,}.gguf"
./llama-quantize --imatrix $IMATRIX $MODEL $OUTPUT $QUANT
echo "Created: $OUTPUT ($(du -h $OUTPUT | cut -f1))"
done
```
## Python usage
### llama-cpp-python
```python
from llama_cpp import Llama
# Load model
llm = Llama(
model_path="./model-q4_k_m.gguf",
n_ctx=4096, # Context window
n_gpu_layers=35, # GPU offload (0 for CPU only)
n_threads=8 # CPU threads
)
# Generate
output = llm(
"What is machine learning?",
max_tokens=256,
temperature=0.7,
stop=["</s>", "\n\n"]
)
print(output["choices"][0]["text"])
```
### Chat completion
```python
from llama_cpp import Llama
llm = Llama(
model_path="./model-q4_k_m.gguf",
n_ctx=4096,
n_gpu_layers=35,
chat_format="llama-3" # Or "chatml", "mistral", etc.
)
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "What is Python?"}
]
response = llm.create_chat_completion(
messages=messages,
max_tokens=256,
temperature=0.7
)
print(response["choices"][0]["message"]["content"])
```
### Streaming
```python
from llama_cpp import Llama
llm = Llama(model_path="./model-q4_k_m.gguf", n_gpu_layers=35)
# Stream tokens
for chunk in llm(
"Explain quantum computing:",
max_tokens=256,
stream=True
):
print(chunk["choices"][0]["text"], end="", flush=True)
```
## Server mode
### Start OpenAI-compatible server
```bash
# Start server
./llama-server -m model-q4_k_m.gguf \
--host 0.0.0.0 \
--port 8080 \
-ngl 35 \
-c 4096
# Or with Python bindings
python -m llama_cpp.server \
--model model-q4_k_m.gguf \
--n_gpu_layers 35 \
--host 0.0.0.0 \
--port 8080
```
### Use with OpenAI client
```python
from openai import OpenAI
client = OpenAI(
base_url="http://localhost:8080/v1",
api_key="not-needed"
)
response = client.chat.completions.create(
model="local-model",
messages=[{"role": "user", "content": "Hello!"}],
max_tokens=256
)
print(response.choices[0].message.content)
```
## Hardware optimization
### Apple Silicon (Metal)
```bash
# Build with Metal
make clean && make GGML_METAL=1
# Run with Metal acceleration
./llama-cli -m model.gguf -ngl 99 -p "Hello"
# Python with Metal
llm = Llama(
model_path="model.gguf",
n_gpu_layers=99, # Offload all layers
n_threads=1 # Metal handles parallelism
)
```
### NVIDIA CUDA
```bash
# Build with CUDA
make clean && make GGML_CUDA=1
# Run with CUDA
./llama-cli -m model.gguf -ngl 35 -p "Hello"
# Specify GPU
CUDA_VISIBLE_DEVICES=0 ./llama-cli -m model.gguf -ngl 35
```
### CPU optimization
```bash
# Build with AVX2/AVX512
make clean && make
# Run with optimal threads
./llama-cli -m model.gguf -t 8 -p "Hello"
# Python CPU config
llm = Llama(
model_path="model.gguf",
n_gpu_layers=0, # CPU only
n_threads=8, # Match physical cores
n_batch=512 # Batch size for prompt processing
)
```
## Integration with tools
### Ollama
```bash
# Create Modelfile
cat > Modelfile << 'EOF'
FROM ./model-q4_k_m.gguf
TEMPLATE """{{ .System }}
{{ .Prompt }}"""
PARAMETER temperature 0.7
PARAMETER num_ctx 4096
EOF
# Create Ollama model
ollama create mymodel -f Modelfile
# Run
ollama run mymodel "Hello!"
```
### LM Studio
1. Place GGUF file in `~/.cache/lm-studio/models/`
2. Open LM Studio and select the model
3. Configure context length and GPU offload
4. Start inference
### text-generation-webui
```bash
# Place in models folder
cp model-q4_k_m.gguf text-generation-webui/models/
# Start with llama.cpp loader
python server.py --model model-q4_k_m.gguf --loader llama.cpp --n-gpu-layers 35
```
## Best practices
1. **Use K-quants**: Q4_K_M offers best quality/size balance
2. **Use imatrix**: Always use importance matrix for Q4 and below
3. **GPU offload**: Offload as many layers as VRAM allows
4. **Context length**: Start with 4096, increase if needed
5. **Thread count**: Match physical CPU cores, not logical
6. **Batch size**: Increase n_batch for faster prompt processing
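Putting these practices together, a typical end-to-end run might look like the following (file names, layer count, and thread count are illustrative):

```bash
# imatrix-assisted Q4_K_M quantization, then inference with GPU offload
./llama-imatrix -m model-f16.gguf -f calibration.txt -o model.imatrix -ngl 35
./llama-quantize --imatrix model.imatrix model-f16.gguf model-q4_k_m.gguf Q4_K_M
# 4096-token context, offload layers to fit VRAM, threads = physical cores
./llama-cli -m model-q4_k_m.gguf -c 4096 -ngl 35 -t 8 -b 512 -p "Hello!"
```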
## Common issues
**Model loads slowly:**
```bash
# Memory mapping (fast lazy loading) is enabled by default;
# use --mlock to keep the model resident in RAM once loaded
./llama-cli -m model.gguf --mlock
```
**Out of memory:**
```bash
# Reduce GPU layers
./llama-cli -m model.gguf -ngl 20 # Reduce from 35
# Or use smaller quantization
./llama-quantize model-f16.gguf model-q3_k_m.gguf Q3_K_M
```
**Poor quality at low bits:**
```bash
# Always use imatrix for Q4 and below
./llama-imatrix -m model-f16.gguf -f calibration.txt -o model.imatrix
./llama-quantize --imatrix model.imatrix model-f16.gguf model-q4_k_m.gguf Q4_K_M
```
## References
- **[Advanced Usage](references/advanced-usage.md)** - Batching, speculative decoding, custom builds
- **[Troubleshooting](references/troubleshooting.md)** - Common issues, debugging, benchmarks
## Resources
- **Repository**: https://github.com/ggml-org/llama.cpp
- **Python Bindings**: https://github.com/abetlen/llama-cpp-python
- **Pre-quantized Models**: https://huggingface.co/TheBloke
- **GGUF Converter**: https://huggingface.co/spaces/ggml-org/gguf-my-repo
- **License**: MIT

View file

@ -0,0 +1,504 @@
# GGUF Advanced Usage Guide
## Speculative Decoding
### Draft Model Approach
```bash
# Use smaller model as draft for faster generation
./llama-speculative \
-m large-model-q4_k_m.gguf \
-md draft-model-q4_k_m.gguf \
-p "Write a story about AI" \
-n 500 \
--draft 8 # Draft tokens before verification
```
### Self-Speculative Decoding
```bash
# Use same model with different context for speculation
./llama-cli -m model-q4_k_m.gguf \
--lookup-cache-static lookup.bin \
--lookup-cache-dynamic lookup-dynamic.bin \
-p "Hello world"
```
## Batched Inference
### Process Multiple Prompts
```python
from llama_cpp import Llama
llm = Llama(
model_path="model-q4_k_m.gguf",
n_ctx=4096,
n_gpu_layers=35,
n_batch=512 # Larger batch for parallel processing
)
prompts = [
"What is Python?",
"Explain machine learning.",
"Describe neural networks."
]
# Process in batch (each prompt gets separate context)
for prompt in prompts:
output = llm(prompt, max_tokens=100)
print(f"Q: {prompt}")
print(f"A: {output['choices'][0]['text']}\n")
```
### Server Batching
```bash
# Start server with batching
./llama-server -m model-q4_k_m.gguf \
--host 0.0.0.0 \
--port 8080 \
-ngl 35 \
-c 4096 \
  --parallel 4 \
  --cont-batching   # 4 concurrent request slots with continuous batching
```
## Custom Model Conversion
### Convert with Vocabulary Modifications
```python
# custom_convert.py
import sys
sys.path.insert(0, './llama.cpp')
from convert_hf_to_gguf import main
from gguf import GGUFWriter
# Custom conversion with modified vocab
def convert_with_custom_vocab(model_path, output_path):
# Load and modify tokenizer
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(model_path)
# Add special tokens if needed
special_tokens = {"additional_special_tokens": ["<|custom|>"]}
tokenizer.add_special_tokens(special_tokens)
tokenizer.save_pretrained(model_path)
# Then run standard conversion
main([model_path, "--outfile", output_path])
```
### Convert Specific Architecture
```bash
# For Mistral-style models
python convert_hf_to_gguf.py ./mistral-model \
--outfile mistral-f16.gguf \
--outtype f16
# For Qwen models
python convert_hf_to_gguf.py ./qwen-model \
--outfile qwen-f16.gguf \
--outtype f16
# For Phi models
python convert_hf_to_gguf.py ./phi-model \
--outfile phi-f16.gguf \
--outtype f16
```
## Advanced Quantization
### Mixed Quantization
```bash
# Quantize different layer types differently
./llama-quantize --allow-requantize --leave-output-tensor \
  model-f16.gguf model-mixed.gguf Q4_K_M
```
### Quantization with Token Embeddings
```bash
# Keep embeddings at higher precision
./llama-quantize --token-embedding-type f16 \
  model-f16.gguf model-q4.gguf Q4_K_M
```
### IQ Quantization (Importance-aware)
```bash
# Ultra-low bit quantization with importance
./llama-quantize --imatrix model.imatrix \
model-f16.gguf model-iq2_xxs.gguf IQ2_XXS
# Available IQ types: IQ2_XXS, IQ2_XS, IQ2_S, IQ3_XXS, IQ3_XS, IQ3_S, IQ4_XS
```
## Memory Optimization
### Memory Mapping
```python
from llama_cpp import Llama
# Use memory mapping for large models
llm = Llama(
model_path="model-q4_k_m.gguf",
use_mmap=True, # Memory map the model
use_mlock=False, # Don't lock in RAM
n_gpu_layers=35
)
```
### Partial GPU Offload
```python
# Calculate layers to offload based on VRAM
import subprocess
def get_free_vram_gb():
result = subprocess.run(
['nvidia-smi', '--query-gpu=memory.free', '--format=csv,nounits,noheader'],
capture_output=True, text=True
)
return int(result.stdout.strip()) / 1024
# Estimate layers based on VRAM (rough: 0.5GB per layer for 7B Q4)
free_vram = get_free_vram_gb()
layers_to_offload = int(free_vram / 0.5)
llm = Llama(
model_path="model-q4_k_m.gguf",
n_gpu_layers=min(layers_to_offload, 35) # Cap at total layers
)
```
### KV Cache Optimization
```python
from llama_cpp import Llama
# Optimize KV cache for long contexts
llm = Llama(
model_path="model-q4_k_m.gguf",
n_ctx=8192, # Large context
n_gpu_layers=35,
    type_k=8,   # GGML_TYPE_Q8_0 for the K cache
    type_v=8,   # GGML_TYPE_Q8_0 for the V cache
    # Or use GGML_TYPE_Q4_0 (value 2) for more compression
)
```
## Context Management
### Context Shifting
```python
from llama_cpp import Llama
llm = Llama(
model_path="model-q4_k_m.gguf",
n_ctx=4096,
n_gpu_layers=35
)
# Handle long conversations with context shifting
conversation = []
max_history = 10
def chat(user_message):
    global conversation  # reassigned below, so declare it global
    conversation.append({"role": "user", "content": user_message})
    # Keep only the most recent history
    if len(conversation) > max_history * 2:
        conversation = conversation[-max_history * 2:]
response = llm.create_chat_completion(
messages=conversation,
max_tokens=256
)
assistant_message = response["choices"][0]["message"]["content"]
conversation.append({"role": "assistant", "content": assistant_message})
return assistant_message
```
### Save and Load State
```bash
# Save state to file
./llama-cli -m model.gguf \
-p "Once upon a time" \
--save-session session.bin \
-n 100
# Load and continue
./llama-cli -m model.gguf \
--load-session session.bin \
-p " and they lived" \
-n 100
```
## Grammar Constrained Generation
### JSON Output
```python
from llama_cpp import Llama, LlamaGrammar
# Define JSON grammar
json_grammar = LlamaGrammar.from_string('''
root ::= object
object ::= "{" ws pair ("," ws pair)* "}" ws
pair ::= string ":" ws value
value ::= string | number | object | array | "true" | "false" | "null"
array ::= "[" ws value ("," ws value)* "]" ws
string ::= "\\"" [^"\\\\]* "\\""
number ::= [0-9]+
ws ::= [ \\t\\n]*
''')
llm = Llama(model_path="model-q4_k_m.gguf", n_gpu_layers=35)
output = llm(
"Output a JSON object with name and age:",
grammar=json_grammar,
max_tokens=100
)
print(output["choices"][0]["text"])
```
### Custom Grammar
```python
# Grammar for specific format
answer_grammar = LlamaGrammar.from_string('''
root ::= "Answer: " letter "\\n" "Explanation: " explanation
letter ::= [A-D]
explanation ::= [a-zA-Z0-9 .,!?]+
''')
output = llm(
"Q: What is 2+2? A) 3 B) 4 C) 5 D) 6",
grammar=answer_grammar,
max_tokens=100
)
```
## LoRA Integration
### Load LoRA Adapter
```bash
# Apply LoRA at runtime
./llama-cli -m base-model-q4_k_m.gguf \
--lora lora-adapter.gguf \
--lora-scale 1.0 \
-p "Hello!"
```
### Multiple LoRA Adapters
```bash
# Stack multiple adapters
./llama-cli -m base-model.gguf \
--lora adapter1.gguf --lora-scale 0.5 \
--lora adapter2.gguf --lora-scale 0.5 \
-p "Hello!"
```
### Python LoRA Usage
```python
from llama_cpp import Llama
llm = Llama(
model_path="base-model-q4_k_m.gguf",
lora_path="lora-adapter.gguf",
lora_scale=1.0,
n_gpu_layers=35
)
```
## Embedding Generation
### Extract Embeddings
```python
from llama_cpp import Llama
llm = Llama(
model_path="model-q4_k_m.gguf",
embedding=True, # Enable embedding mode
n_gpu_layers=35
)
# Get embeddings
embeddings = llm.embed("This is a test sentence.")
print(f"Embedding dimension: {len(embeddings)}")
```
### Batch Embeddings
```python
texts = [
"Machine learning is fascinating.",
"Deep learning uses neural networks.",
"Python is a programming language."
]
embeddings = [llm.embed(text) for text in texts]
# Calculate similarity
import numpy as np
def cosine_similarity(a, b):
return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
sim = cosine_similarity(embeddings[0], embeddings[1])
print(f"Similarity: {sim:.4f}")
```
## Performance Tuning
### Benchmark Script
```python
import time
from llama_cpp import Llama
def benchmark(model_path, prompt, n_tokens=100, n_runs=5):
llm = Llama(
model_path=model_path,
n_gpu_layers=35,
n_ctx=2048,
verbose=False
)
# Warmup
llm(prompt, max_tokens=10)
# Benchmark
times = []
for _ in range(n_runs):
start = time.time()
output = llm(prompt, max_tokens=n_tokens)
elapsed = time.time() - start
times.append(elapsed)
avg_time = sum(times) / len(times)
tokens_per_sec = n_tokens / avg_time
print(f"Model: {model_path}")
print(f"Avg time: {avg_time:.2f}s")
print(f"Tokens/sec: {tokens_per_sec:.1f}")
return tokens_per_sec
# Compare quantizations
for quant in ["q4_k_m", "q5_k_m", "q8_0"]:
benchmark(f"model-{quant}.gguf", "Explain quantum computing:", 100)
```
### Optimal Configuration Finder
```python
def find_optimal_config(model_path, target_vram_gb=8):
"""Find optimal n_gpu_layers and n_batch for target VRAM."""
    from llama_cpp import Llama
    import gc
    import time
best_config = None
best_speed = 0
for n_gpu_layers in range(0, 50, 5):
for n_batch in [128, 256, 512, 1024]:
try:
gc.collect()
llm = Llama(
model_path=model_path,
n_gpu_layers=n_gpu_layers,
n_batch=n_batch,
n_ctx=2048,
verbose=False
)
# Quick benchmark
start = time.time()
llm("Hello", max_tokens=50)
speed = 50 / (time.time() - start)
if speed > best_speed:
best_speed = speed
best_config = {
"n_gpu_layers": n_gpu_layers,
"n_batch": n_batch,
"speed": speed
}
del llm
gc.collect()
except Exception as e:
print(f"OOM at layers={n_gpu_layers}, batch={n_batch}")
break
return best_config
```
## Multi-GPU Setup
### Distribute Across GPUs
```bash
# Split model across multiple GPUs
./llama-cli -m large-model.gguf \
--tensor-split 0.5,0.5 \
-ngl 60 \
-p "Hello!"
```
### Python Multi-GPU
```python
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
from llama_cpp import Llama
llm = Llama(
model_path="large-model-q4_k_m.gguf",
n_gpu_layers=60,
tensor_split=[0.5, 0.5] # Split evenly across 2 GPUs
)
```
## Custom Builds
### Build with All Optimizations
```bash
# Clean build with all CPU optimizations
make clean
LLAMA_OPENBLAS=1 LLAMA_BLAS_VENDOR=OpenBLAS make -j
# With CUDA and cuBLAS
make clean
GGML_CUDA=1 LLAMA_CUBLAS=1 make -j
# With specific CUDA architecture
GGML_CUDA=1 CUDA_DOCKER_ARCH=sm_86 make -j
```
### CMake Build
```bash
mkdir build && cd build
cmake .. -DGGML_CUDA=ON -DCMAKE_BUILD_TYPE=Release
cmake --build . --config Release -j
```

View file

@ -0,0 +1,442 @@
# GGUF Troubleshooting Guide
## Installation Issues
### Build Fails
**Error**: `make: *** No targets specified and no makefile found`
**Fix**:
```bash
# Ensure you're in llama.cpp directory
cd llama.cpp
make
```
**Error**: `fatal error: cuda_runtime.h: No such file or directory`
**Fix**:
```bash
# Install CUDA toolkit
# Ubuntu
sudo apt install nvidia-cuda-toolkit
# Or set CUDA path
export CUDA_PATH=/usr/local/cuda
export PATH=$CUDA_PATH/bin:$PATH
make GGML_CUDA=1
```
### Python Bindings Issues
**Error**: `ERROR: Failed building wheel for llama-cpp-python`
**Fix**:
```bash
# Install build dependencies
pip install cmake scikit-build-core
# For CUDA support
CMAKE_ARGS="-DGGML_CUDA=on" pip install llama-cpp-python --force-reinstall --no-cache-dir
# For Metal (macOS)
CMAKE_ARGS="-DGGML_METAL=on" pip install llama-cpp-python --force-reinstall --no-cache-dir
```
**Error**: `ImportError: libcudart.so.XX: cannot open shared object file`
**Fix**:
```bash
# Add CUDA libraries to path
export LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH
# Or reinstall with correct CUDA version
pip uninstall llama-cpp-python
CUDACXX=/usr/local/cuda/bin/nvcc CMAKE_ARGS="-DGGML_CUDA=on" pip install llama-cpp-python
```
## Conversion Issues
### Model Not Supported
**Error**: `KeyError: 'model.embed_tokens.weight'`
**Fix**:
```bash
# Check model architecture
python -c "from transformers import AutoConfig; print(AutoConfig.from_pretrained('./model').architectures)"
# Use appropriate conversion script
# For most models:
python convert_hf_to_gguf.py ./model --outfile model.gguf
# For older models, check if legacy script needed
```
### Vocabulary Mismatch
**Error**: `RuntimeError: Vocabulary size mismatch`
**Fix**:
```python
# Ensure tokenizer matches model
from transformers import AutoTokenizer, AutoModelForCausalLM
tokenizer = AutoTokenizer.from_pretrained("./model")
model = AutoModelForCausalLM.from_pretrained("./model")
print(f"Tokenizer vocab size: {len(tokenizer)}")
print(f"Model vocab size: {model.config.vocab_size}")
# If mismatch, resize embeddings before conversion
model.resize_token_embeddings(len(tokenizer))
model.save_pretrained("./model-fixed")
```
### Out of Memory During Conversion
**Error**: `torch.cuda.OutOfMemoryError` during conversion
**Fix**:
```bash
# Use CPU for conversion
CUDA_VISIBLE_DEVICES="" python convert_hf_to_gguf.py ./model --outfile model.gguf
# Or use low memory mode
python convert_hf_to_gguf.py ./model --outfile model.gguf --outtype f16
```
## Quantization Issues
### Wrong Output File Size
**Problem**: Quantized file is larger than expected
**Check**:
```bash
# Verify quantization type
./llama-cli -m model.gguf --verbose
# Expected sizes for 7B model:
# Q4_K_M: ~4.1 GB
# Q5_K_M: ~4.8 GB
# Q8_0: ~7.2 GB
# F16: ~13.5 GB
```
### Quantization Crashes
**Error**: `Segmentation fault` during quantization
**Fix**:
```bash
# Increase stack size
ulimit -s unlimited
# Or use fewer threads (the thread count is the final positional argument)
./llama-quantize model-f16.gguf model-q4.gguf Q4_K_M 4
```
### Poor Quality After Quantization
**Problem**: Model outputs gibberish after quantization
**Solutions**:
1. **Use importance matrix**:
```bash
# Generate imatrix with good calibration data
./llama-imatrix -m model-f16.gguf \
-f wiki_sample.txt \
--chunk 512 \
-o model.imatrix
# Quantize with imatrix
./llama-quantize --imatrix model.imatrix \
model-f16.gguf model-q4_k_m.gguf Q4_K_M
```
2. **Try higher precision**:
```bash
# Use Q5_K_M or Q6_K instead of Q4
./llama-quantize model-f16.gguf model-q5_k_m.gguf Q5_K_M
```
3. **Check original model**:
```bash
# Test FP16 version first
./llama-cli -m model-f16.gguf -p "Hello, how are you?" -n 50
```
## Inference Issues
### Slow Generation
**Problem**: Generation is slower than expected
**Solutions**:
1. **Enable GPU offload**:
```bash
./llama-cli -m model.gguf -ngl 35 -p "Hello"
```
2. **Optimize batch size**:
```python
llm = Llama(
model_path="model.gguf",
n_batch=512, # Increase for faster prompt processing
n_gpu_layers=35
)
```
3. **Use appropriate threads**:
```bash
# Match physical cores, not logical
./llama-cli -m model.gguf -t 8 -p "Hello"
```
4. **Enable Flash Attention** (if supported):
```bash
./llama-cli -m model.gguf -ngl 35 --flash-attn -p "Hello"
```
### Out of Memory
**Error**: `CUDA out of memory` or system freeze
**Solutions**:
1. **Reduce GPU layers**:
```python
# Start low and increase
llm = Llama(model_path="model.gguf", n_gpu_layers=10)
```
2. **Use smaller quantization**:
```bash
./llama-quantize model-f16.gguf model-q3_k_m.gguf Q3_K_M
```
3. **Reduce context length**:
```python
llm = Llama(
model_path="model.gguf",
n_ctx=2048, # Reduce from 4096
n_gpu_layers=35
)
```
4. **Quantize KV cache**:
```python
llm = Llama(
model_path="model.gguf",
type_k=2, # Q4_0 for K cache
type_v=2, # Q4_0 for V cache
n_gpu_layers=35
)
```
### Garbage Output
**Problem**: Model outputs random characters or nonsense
**Diagnose**:
```python
# Check model loading
llm = Llama(model_path="model.gguf", verbose=True)
# Test with simple prompt
output = llm("1+1=", max_tokens=5, temperature=0)
print(output)
```
**Solutions**:
1. **Check model integrity**:
```bash
# Verify GGUF file
./llama-cli -m model.gguf --verbose 2>&1 | head -50
```
2. **Use correct chat format**:
```python
llm = Llama(
model_path="model.gguf",
chat_format="llama-3" # Match your model: chatml, mistral, etc.
)
```
3. **Check temperature**:
```python
# Use lower temperature for deterministic output
output = llm("Hello", max_tokens=50, temperature=0.1)
```
### Token Issues
**Error**: `RuntimeError: unknown token` or encoding errors
**Fix**:
```python
# Ensure UTF-8 encoding
prompt = "Hello, world!".encode('utf-8').decode('utf-8')
output = llm(prompt, max_tokens=50)
```
## Server Issues
### Connection Refused
**Error**: `Connection refused` when accessing server
**Fix**:
```bash
# Bind to all interfaces
./llama-server -m model.gguf --host 0.0.0.0 --port 8080
# Check if port is in use
lsof -i :8080
```
### Server Crashes Under Load
**Problem**: Server crashes with multiple concurrent requests
**Solutions**:
1. **Limit parallelism**:
```bash
./llama-server -m model.gguf \
--parallel 2 \
-c 4096 \
--cont-batching
```
2. **Add request timeout**:
```bash
./llama-server -m model.gguf --timeout 300
```
3. **Monitor memory**:
```bash
watch -n 1 nvidia-smi # For GPU
watch -n 1 free -h # For RAM
```
### API Compatibility Issues
**Problem**: OpenAI client not working with server
**Fix**:
```python
from openai import OpenAI
# Use correct base URL format
client = OpenAI(
base_url="http://localhost:8080/v1", # Include /v1
api_key="not-needed"
)
# Use correct model name
response = client.chat.completions.create(
model="local", # Or the actual model name
messages=[{"role": "user", "content": "Hello"}]
)
```
## Apple Silicon Issues
### Metal Not Working
**Problem**: Metal acceleration not enabled
**Check**:
```bash
# Verify Metal support
./llama-cli -m model.gguf --verbose 2>&1 | grep -i metal
```
**Fix**:
```bash
# Rebuild with Metal
make clean
make GGML_METAL=1
# Python bindings
CMAKE_ARGS="-DGGML_METAL=on" pip install llama-cpp-python --force-reinstall
```
### Incorrect Memory Usage on M1/M2
**Problem**: Model uses too much unified memory
**Fix**:
```python
# Offload all layers for Metal
llm = Llama(
model_path="model.gguf",
n_gpu_layers=99, # Offload everything
n_threads=1 # Metal handles parallelism
)
```
## Debugging
### Enable Verbose Output
```bash
# CLI verbose mode
./llama-cli -m model.gguf --verbose -p "Hello" -n 50
# Python verbose
llm = Llama(model_path="model.gguf", verbose=True)
```
### Check Model Metadata
```bash
# View GGUF metadata
./llama-cli -m model.gguf --verbose 2>&1 | head -100
```
### Validate GGUF File
```python
import struct
def validate_gguf(filepath):
with open(filepath, 'rb') as f:
magic = f.read(4)
if magic != b'GGUF':
print(f"Invalid magic: {magic}")
return False
version = struct.unpack('<I', f.read(4))[0]
print(f"GGUF version: {version}")
tensor_count = struct.unpack('<Q', f.read(8))[0]
metadata_count = struct.unpack('<Q', f.read(8))[0]
print(f"Tensors: {tensor_count}, Metadata: {metadata_count}")
return True
validate_gguf("model.gguf")
```
## Getting Help
1. **GitHub Issues**: https://github.com/ggml-org/llama.cpp/issues
2. **Discussions**: https://github.com/ggml-org/llama.cpp/discussions
3. **Reddit**: r/LocalLLaMA
### Reporting Issues
Include:
- llama.cpp version/commit hash
- Build command used
- Model name and quantization
- Full error message/stack trace
- Hardware: CPU/GPU model, RAM, VRAM
- OS version
- Minimal reproduction steps
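A quick way to gather most of this information (assumes an NVIDIA GPU; adapt the GPU query for other hardware):

```bash
# Version/commit, GPU, and OS details for the report
./llama-cli --version
nvidia-smi --query-gpu=name,memory.total,driver_version --format=csv
uname -a
```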

View file

@ -0,0 +1,575 @@
---
name: guidance
description: Control LLM output with regex and grammars, guarantee valid JSON/XML/code generation, enforce structured formats, and build multi-step workflows with Guidance - Microsoft Research's constrained generation framework
version: 1.0.0
author: Orchestra Research
license: MIT
dependencies: [guidance, transformers]
metadata:
hermes:
tags: [Prompt Engineering, Guidance, Constrained Generation, Structured Output, JSON Validation, Grammar, Microsoft Research, Format Enforcement, Multi-Step Workflows]
---
# Guidance: Constrained LLM Generation
## When to Use This Skill
Use Guidance when you need to:
- **Control LLM output syntax** with regex or grammars
- **Guarantee valid JSON/XML/code** generation
- **Reduce latency** vs traditional prompting approaches
- **Enforce structured formats** (dates, emails, IDs, etc.)
- **Build multi-step workflows** with Pythonic control flow
- **Prevent invalid outputs** through grammatical constraints
**GitHub Stars**: 18,000+ | **From**: Microsoft Research
## Installation
```bash
# Base installation
pip install guidance
# With specific backends
pip install guidance[transformers] # Hugging Face models
pip install guidance[llama_cpp] # llama.cpp models
```
## Quick Start
### Basic Example: Structured Generation
```python
from guidance import models, gen
# Load model (supports OpenAI, Transformers, llama.cpp)
lm = models.OpenAI("gpt-4")
# Generate with constraints
result = lm + "The capital of France is " + gen("capital", max_tokens=5)
print(result["capital"]) # "Paris"
```
### With Anthropic Claude
```python
from guidance import models, gen, system, user, assistant
# Configure Claude
lm = models.Anthropic("claude-sonnet-4-5-20250929")
# Use context managers for chat format
with system():
lm += "You are a helpful assistant."
with user():
lm += "What is the capital of France?"
with assistant():
lm += gen(max_tokens=20)
```
## Core Concepts
### 1. Context Managers
Guidance uses Pythonic context managers for chat-style interactions.
```python
from guidance import system, user, assistant, gen
lm = models.Anthropic("claude-sonnet-4-5-20250929")
# System message
with system():
lm += "You are a JSON generation expert."
# User message
with user():
lm += "Generate a person object with name and age."
# Assistant response
with assistant():
lm += gen("response", max_tokens=100)
print(lm["response"])
```
**Benefits:**
- Natural chat flow
- Clear role separation
- Easy to read and maintain
### 2. Constrained Generation
Guidance ensures outputs match specified patterns using regex or grammars.
#### Regex Constraints
```python
from guidance import models, gen
lm = models.Anthropic("claude-sonnet-4-5-20250929")
# Constrain to valid email format
lm += "Email: " + gen("email", regex=r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}")
# Constrain to date format (YYYY-MM-DD)
lm += "Date: " + gen("date", regex=r"\d{4}-\d{2}-\d{2}")
# Constrain to phone number
lm += "Phone: " + gen("phone", regex=r"\d{3}-\d{3}-\d{4}")
print(lm["email"]) # Guaranteed valid email
print(lm["date"]) # Guaranteed YYYY-MM-DD format
```
**How it works:**
- Regex converted to grammar at token level
- Invalid tokens filtered during generation
- Model can only produce matching outputs
#### Selection Constraints
```python
from guidance import models, gen, select
lm = models.Anthropic("claude-sonnet-4-5-20250929")
# Constrain to specific choices
lm += "Sentiment: " + select(["positive", "negative", "neutral"], name="sentiment")
# Multiple-choice selection
lm += "Best answer: " + select(
["A) Paris", "B) London", "C) Berlin", "D) Madrid"],
name="answer"
)
print(lm["sentiment"]) # One of: positive, negative, neutral
print(lm["answer"]) # One of: A, B, C, or D
```
### 3. Token Healing
Guidance automatically "heals" token boundaries between prompt and generation.
**Problem:** Tokenization creates unnatural boundaries.
```python
# Without token healing
prompt = "The capital of France is "
# Last token: " is "
# First generated token might be " Par" (with leading space)
# Result: "The capital of France is Paris" (double space!)
```
**Solution:** Guidance backs up one token and regenerates.
```python
from guidance import models, gen
lm = models.Anthropic("claude-sonnet-4-5-20250929")
# Token healing enabled by default
lm += "The capital of France is " + gen("capital", max_tokens=5)
# Result: "The capital of France is Paris" (correct spacing)
```
**Benefits:**
- Natural text boundaries
- No awkward spacing issues
- Better model performance (sees natural token sequences)
### 4. Grammar-Based Generation
Define complex structures using context-free grammars.
```python
from guidance import models, gen
lm = models.Anthropic("claude-sonnet-4-5-20250929")
# JSON grammar (simplified)
json_grammar = """
{
"name": <gen name regex="[A-Za-z ]+" max_tokens=20>,
"age": <gen age regex="[0-9]+" max_tokens=3>,
"email": <gen email regex="[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\.[a-zA-Z]{2,}" max_tokens=50>
}
"""
# Generate valid JSON
lm += gen("person", grammar=json_grammar)
print(lm["person"]) # Guaranteed valid JSON structure
```
**Use cases:**
- Complex structured outputs
- Nested data structures
- Programming language syntax
- Domain-specific languages
### 5. Guidance Functions
Create reusable generation patterns with the `@guidance` decorator.
```python
from guidance import guidance, gen, models
@guidance
def generate_person(lm):
"""Generate a person with name and age."""
lm += "Name: " + gen("name", max_tokens=20, stop="\n")
lm += "\nAge: " + gen("age", regex=r"[0-9]+", max_tokens=3)
return lm
# Use the function
lm = models.Anthropic("claude-sonnet-4-5-20250929")
lm = generate_person(lm)
print(lm["name"])
print(lm["age"])
```
**Stateful Functions:**
```python
from guidance import guidance, gen, select

@guidance(stateless=False)
def react_agent(lm, question, tools, max_rounds=5):
"""ReAct agent with tool use."""
lm += f"Question: {question}\n\n"
for i in range(max_rounds):
# Thought
lm += f"Thought {i+1}: " + gen("thought", stop="\n")
# Action
lm += "\nAction: " + select(list(tools.keys()), name="action")
# Execute tool
tool_result = tools[lm["action"]]()
lm += f"\nObservation: {tool_result}\n\n"
# Check if done
lm += "Done? " + select(["Yes", "No"], name="done")
if lm["done"] == "Yes":
break
# Final answer
lm += "\nFinal Answer: " + gen("answer", max_tokens=100)
return lm
```
## Backend Configuration
### Anthropic Claude
```python
from guidance import models
lm = models.Anthropic(
model="claude-sonnet-4-5-20250929",
api_key="your-api-key" # Or set ANTHROPIC_API_KEY env var
)
```
### OpenAI
```python
lm = models.OpenAI(
model="gpt-4o-mini",
api_key="your-api-key" # Or set OPENAI_API_KEY env var
)
```
### Local Models (Transformers)
```python
from guidance.models import Transformers
lm = Transformers(
"microsoft/Phi-4-mini-instruct",
device="cuda" # Or "cpu"
)
```
### Local Models (llama.cpp)
```python
from guidance.models import LlamaCpp
lm = LlamaCpp(
model_path="/path/to/model.gguf",
n_ctx=4096,
n_gpu_layers=35
)
```
## Common Patterns
### Pattern 1: JSON Generation
```python
from guidance import models, gen, system, user, assistant
lm = models.Anthropic("claude-sonnet-4-5-20250929")
with system():
lm += "You generate valid JSON."
with user():
lm += "Generate a user profile with name, age, and email."
with assistant():
lm += """{
"name": """ + gen("name", regex=r'"[A-Za-z ]+"', max_tokens=30) + """,
"age": """ + gen("age", regex=r"[0-9]+", max_tokens=3) + """,
"email": """ + gen("email", regex=r'"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}"', max_tokens=50) + """
}"""
print(lm) # Valid JSON guaranteed
```
### Pattern 2: Classification
```python
from guidance import models, gen, select
lm = models.Anthropic("claude-sonnet-4-5-20250929")
text = "This product is amazing! I love it."
lm += f"Text: {text}\n"
lm += "Sentiment: " + select(["positive", "negative", "neutral"], name="sentiment")
lm += "\nConfidence: " + gen("confidence", regex=r"[0-9]+", max_tokens=3) + "%"
print(f"Sentiment: {lm['sentiment']}")
print(f"Confidence: {lm['confidence']}%")
```
### Pattern 3: Multi-Step Reasoning
```python
from guidance import models, gen, guidance
@guidance
def chain_of_thought(lm, question):
"""Generate answer with step-by-step reasoning."""
lm += f"Question: {question}\n\n"
# Generate multiple reasoning steps
for i in range(3):
lm += f"Step {i+1}: " + gen(f"step_{i+1}", stop="\n", max_tokens=100) + "\n"
# Final answer
lm += "\nTherefore, the answer is: " + gen("answer", max_tokens=50)
return lm
lm = models.Anthropic("claude-sonnet-4-5-20250929")
lm = chain_of_thought(lm, "What is 15% of 200?")
print(lm["answer"])
```
### Pattern 4: ReAct Agent
```python
from guidance import models, gen, select, guidance
@guidance(stateless=False)
def react_agent(lm, question):
"""ReAct agent with tool use."""
tools = {
"calculator": lambda expr: eval(expr),
"search": lambda query: f"Search results for: {query}",
}
lm += f"Question: {question}\n\n"
for round in range(5):
# Thought
lm += f"Thought: " + gen("thought", stop="\n") + "\n"
# Action selection
lm += "Action: " + select(["calculator", "search", "answer"], name="action")
if lm["action"] == "answer":
lm += "\nFinal Answer: " + gen("answer", max_tokens=100)
break
# Action input
lm += "\nAction Input: " + gen("action_input", stop="\n") + "\n"
# Execute tool
if lm["action"] in tools:
result = tools[lm["action"]](lm["action_input"])
lm += f"Observation: {result}\n\n"
return lm
lm = models.Anthropic("claude-sonnet-4-5-20250929")
lm = react_agent(lm, "What is 25 * 4 + 10?")
print(lm["answer"])
```
### Pattern 5: Data Extraction
```python
from guidance import models, gen, guidance
@guidance
def extract_entities(lm, text):
"""Extract structured entities from text."""
lm += f"Text: {text}\n\n"
# Extract person
lm += "Person: " + gen("person", stop="\n", max_tokens=30) + "\n"
# Extract organization
lm += "Organization: " + gen("organization", stop="\n", max_tokens=30) + "\n"
# Extract date
lm += "Date: " + gen("date", regex=r"\d{4}-\d{2}-\d{2}", max_tokens=10) + "\n"
# Extract location
lm += "Location: " + gen("location", stop="\n", max_tokens=30) + "\n"
return lm
text = "Tim Cook announced at Apple Park on 2024-09-15 in Cupertino."
lm = models.Anthropic("claude-sonnet-4-5-20250929")
lm = extract_entities(lm, text)
print(f"Person: {lm['person']}")
print(f"Organization: {lm['organization']}")
print(f"Date: {lm['date']}")
print(f"Location: {lm['location']}")
```
## Best Practices
### 1. Use Regex for Format Validation
```python
# ✅ Good: Regex ensures valid format
lm += "Email: " + gen("email", regex=r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}")
# ❌ Bad: Free generation may produce invalid emails
lm += "Email: " + gen("email", max_tokens=50)
```
### 2. Use select() for Fixed Categories
```python
# ✅ Good: Guaranteed valid category
lm += "Status: " + select(["pending", "approved", "rejected"], name="status")
# ❌ Bad: May generate typos or invalid values
lm += "Status: " + gen("status", max_tokens=20)
```
### 3. Leverage Token Healing
```python
# Token healing is enabled by default
# No special action needed - just concatenate naturally
lm += "The capital is " + gen("capital") # Automatic healing
```
### 4. Use stop Sequences
```python
# ✅ Good: Stop at newline for single-line outputs
lm += "Name: " + gen("name", stop="\n")
# ❌ Bad: May generate multiple lines
lm += "Name: " + gen("name", max_tokens=50)
```
### 5. Create Reusable Functions
```python
# ✅ Good: Reusable pattern
@guidance
def generate_person(lm):
lm += "Name: " + gen("name", stop="\n")
lm += "\nAge: " + gen("age", regex=r"[0-9]+")
return lm
# Use multiple times
lm = generate_person(lm)
lm += "\n\n"
lm = generate_person(lm)
```
### 6. Balance Constraints
```python
# ✅ Good: Reasonable constraints
lm += gen("name", regex=r"[A-Za-z ]+", max_tokens=30)
# ❌ Too strict: May fail or be very slow
lm += gen("name", regex=r"^(John|Jane)$", max_tokens=10)
```
## Comparison to Alternatives
| Feature | Guidance | Instructor | Outlines | LMQL |
|---------|----------|------------|----------|------|
| Regex Constraints | ✅ Yes | ❌ No | ✅ Yes | ✅ Yes |
| Grammar Support | ✅ CFG | ❌ No | ✅ CFG | ✅ CFG |
| Pydantic Validation | ❌ No | ✅ Yes | ✅ Yes | ❌ No |
| Token Healing | ✅ Yes | ❌ No | ✅ Yes | ❌ No |
| Local Models | ✅ Yes | ⚠️ Limited | ✅ Yes | ✅ Yes |
| API Models | ✅ Yes | ✅ Yes | ⚠️ Limited | ✅ Yes |
| Pythonic Syntax | ✅ Yes | ✅ Yes | ✅ Yes | ❌ SQL-like |
| Learning Curve | Low | Low | Medium | High |
**When to choose Guidance:**
- Need regex/grammar constraints
- Want token healing
- Building complex workflows with control flow
- Using local models (Transformers, llama.cpp)
- Prefer Pythonic syntax
**When to choose alternatives:**
- Instructor: Need Pydantic validation with automatic retrying
- Outlines: Need JSON schema validation
- LMQL: Prefer declarative query syntax
## Performance Characteristics
**Latency Reduction:**
- 30-50% faster than traditional prompting for constrained outputs
- Token healing reduces unnecessary regeneration
- Grammar constraints prevent invalid token generation
**Memory Usage:**
- Minimal overhead vs unconstrained generation
- Grammar compilation cached after first use
- Efficient token filtering at inference time
**Token Efficiency:**
- Prevents wasted tokens on invalid outputs
- No need for retry loops
- Direct path to valid outputs
## Resources
- **Documentation**: https://guidance.readthedocs.io
- **GitHub**: https://github.com/guidance-ai/guidance (18k+ stars)
- **Notebooks**: https://github.com/guidance-ai/guidance/tree/main/notebooks
- **Discord**: Community support available
## See Also
- `references/constraints.md` - Comprehensive regex and grammar patterns
- `references/backends.md` - Backend-specific configuration
- `references/examples.md` - Production-ready examples

View file

@ -0,0 +1,554 @@
# Backend Configuration Guide
Complete guide to configuring Guidance with different LLM backends.
## Table of Contents
- API-Based Models (Anthropic, OpenAI)
- Local Models (Transformers, llama.cpp)
- Backend Comparison
- Performance Tuning
- Advanced Configuration
## API-Based Models
### Anthropic Claude
#### Basic Setup
```python
from guidance import models
# Using environment variable
lm = models.Anthropic("claude-sonnet-4-5-20250929")
# Reads ANTHROPIC_API_KEY from environment
# Explicit API key
lm = models.Anthropic(
model="claude-sonnet-4-5-20250929",
api_key="your-api-key-here"
)
```
#### Available Models
```python
# Claude Sonnet 4.5 (latest, recommended)
lm = models.Anthropic("claude-sonnet-4-5-20250929")
# Claude Sonnet 3.7 (fast, cost-effective)
lm = models.Anthropic("claude-sonnet-3.7-20250219")
# Claude 3 Opus (Most capable)
lm = models.Anthropic("claude-3-opus-20240229")
# Claude 3.5 Haiku (Fastest, cheapest)
lm = models.Anthropic("claude-3-5-haiku-20241022")
```
#### Configuration Options
```python
lm = models.Anthropic(
model="claude-sonnet-4-5-20250929",
api_key="your-api-key",
max_tokens=4096, # Max tokens to generate
temperature=0.7, # Sampling temperature (0-1)
top_p=0.9, # Nucleus sampling
timeout=30, # Request timeout (seconds)
max_retries=3 # Retry failed requests
)
```
#### With Context Managers
```python
from guidance import models, system, user, assistant, gen
lm = models.Anthropic("claude-sonnet-4-5-20250929")
with system():
lm += "You are a helpful assistant."
with user():
lm += "What is the capital of France?"
with assistant():
lm += gen(max_tokens=50)
print(lm)
```
### OpenAI
#### Basic Setup
```python
from guidance import models
# Using environment variable
lm = models.OpenAI("gpt-4o")
# Reads OPENAI_API_KEY from environment
# Explicit API key
lm = models.OpenAI(
model="gpt-4o",
api_key="your-api-key-here"
)
```
#### Available Models
```python
# GPT-4o (Latest, multimodal)
lm = models.OpenAI("gpt-4o")
# GPT-4o Mini (Fast, cost-effective)
lm = models.OpenAI("gpt-4o-mini")
# GPT-4 Turbo
lm = models.OpenAI("gpt-4-turbo")
# GPT-3.5 Turbo (Cheapest)
lm = models.OpenAI("gpt-3.5-turbo")
```
#### Configuration Options
```python
lm = models.OpenAI(
model="gpt-4o-mini",
api_key="your-api-key",
max_tokens=2048,
temperature=0.7,
top_p=1.0,
frequency_penalty=0.0,
presence_penalty=0.0,
timeout=30
)
```
#### Chat Format
```python
from guidance import models, gen
lm = models.OpenAI("gpt-4o-mini")
# OpenAI uses chat format
lm += [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "What is 2+2?"}
]
# Generate response
lm += gen(max_tokens=50)
```
### Azure OpenAI
```python
from guidance import models
lm = models.AzureOpenAI(
model="gpt-4o",
azure_endpoint="https://your-resource.openai.azure.com/",
api_key="your-azure-api-key",
api_version="2024-02-15-preview",
deployment_name="your-deployment-name"
)
```
## Local Models
### Transformers (Hugging Face)
#### Basic Setup
```python
from guidance.models import Transformers
# Load model from Hugging Face
lm = Transformers("microsoft/Phi-4-mini-instruct")
```
#### GPU Configuration
```python
# Use GPU
lm = Transformers(
"microsoft/Phi-4-mini-instruct",
device="cuda"
)
# Use specific GPU
lm = Transformers(
"microsoft/Phi-4-mini-instruct",
device="cuda:0" # GPU 0
)
# Use CPU
lm = Transformers(
"microsoft/Phi-4-mini-instruct",
device="cpu"
)
```
#### Advanced Configuration
```python
lm = Transformers(
"microsoft/Phi-4-mini-instruct",
device="cuda",
torch_dtype="float16", # Use FP16 (faster, less memory)
load_in_8bit=True, # 8-bit quantization
max_memory={0: "20GB"}, # GPU memory limit
offload_folder="./offload" # Offload to disk if needed
)
```
#### Popular Models
```python
# Phi-4 (Microsoft)
lm = Transformers("microsoft/Phi-4-mini-instruct")
lm = Transformers("microsoft/Phi-3-medium-4k-instruct")
# Llama 3 (Meta)
lm = Transformers("meta-llama/Llama-3.1-8B-Instruct")
lm = Transformers("meta-llama/Llama-3.1-70B-Instruct")
# Mistral (Mistral AI)
lm = Transformers("mistralai/Mistral-7B-Instruct-v0.3")
lm = Transformers("mistralai/Mixtral-8x7B-Instruct-v0.1")
# Qwen (Alibaba)
lm = Transformers("Qwen/Qwen2.5-7B-Instruct")
# Gemma (Google)
lm = Transformers("google/gemma-2-9b-it")
```
#### Generation Configuration
```python
lm = Transformers(
"microsoft/Phi-4-mini-instruct",
device="cuda"
)
# Configure generation
from guidance import gen
result = lm + gen(
max_tokens=100,
temperature=0.7,
top_p=0.9,
top_k=50,
repetition_penalty=1.1
)
```
### llama.cpp
#### Basic Setup
```python
from guidance.models import LlamaCpp
# Load GGUF model
lm = LlamaCpp(
model_path="/path/to/model.gguf",
n_ctx=4096 # Context window
)
```
#### GPU Configuration
```python
# Use GPU acceleration
lm = LlamaCpp(
model_path="/path/to/model.gguf",
n_ctx=4096,
n_gpu_layers=35, # Offload 35 layers to GPU
n_threads=8 # CPU threads for remaining layers
)
# Full GPU offload
lm = LlamaCpp(
model_path="/path/to/model.gguf",
n_ctx=4096,
n_gpu_layers=-1 # Offload all layers
)
```
#### Advanced Configuration
```python
lm = LlamaCpp(
model_path="/path/to/llama-3.1-8b-instruct.Q4_K_M.gguf",
n_ctx=8192, # Context window (tokens)
n_gpu_layers=35, # GPU layers
n_threads=8, # CPU threads
n_batch=512, # Batch size for prompt processing
use_mmap=True, # Memory-map the model file
use_mlock=False, # Lock model in RAM
seed=42, # Random seed
verbose=False # Suppress verbose output
)
```
#### Quantized Models
```python
# Q4_K_M (4-bit, recommended for most cases)
lm = LlamaCpp("/path/to/model.Q4_K_M.gguf")
# Q5_K_M (5-bit, better quality)
lm = LlamaCpp("/path/to/model.Q5_K_M.gguf")
# Q8_0 (8-bit, high quality)
lm = LlamaCpp("/path/to/model.Q8_0.gguf")
# F16 (16-bit float, highest quality)
lm = LlamaCpp("/path/to/model.F16.gguf")
```
#### Popular GGUF Models
```python
# Llama 3.1
lm = LlamaCpp("llama-3.1-8b-instruct.Q4_K_M.gguf")
# Mistral
lm = LlamaCpp("mistral-7b-instruct-v0.3.Q4_K_M.gguf")
# Phi-4
lm = LlamaCpp("phi-4-mini-instruct.Q4_K_M.gguf")
```
## Backend Comparison
### Feature Matrix
| Feature | Anthropic | OpenAI | Transformers | llama.cpp |
|---------|-----------|--------|--------------|-----------|
| Constrained Generation | ✅ Full | ✅ Full | ✅ Full | ✅ Full |
| Token Healing | ✅ Yes | ✅ Yes | ✅ Yes | ✅ Yes |
| Streaming | ✅ Yes | ✅ Yes | ✅ Yes | ✅ Yes |
| GPU Support | N/A | N/A | ✅ Yes | ✅ Yes |
| Quantization | N/A | N/A | ✅ Yes | ✅ Yes |
| Cost | $$$ | $$$ | Free | Free |
| Latency | Low | Low | Medium | Low |
| Setup Difficulty | Easy | Easy | Medium | Medium |
### Performance Characteristics
**Anthropic Claude:**
- **Latency**: 200-500ms (API call)
- **Throughput**: Limited by API rate limits
- **Cost**: $3-15 per 1M input tokens
- **Best for**: Production systems, high-quality outputs
**OpenAI:**
- **Latency**: 200-400ms (API call)
- **Throughput**: Limited by API rate limits
- **Cost**: $0.15-30 per 1M input tokens
- **Best for**: Cost-sensitive production, gpt-4o-mini
**Transformers:**
- **Latency**: 50-200ms (local inference)
- **Throughput**: GPU-dependent (10-100 tokens/sec)
- **Cost**: Hardware cost only
- **Best for**: Privacy-sensitive, high-volume, experimentation
**llama.cpp:**
- **Latency**: 30-150ms (local inference)
- **Throughput**: Hardware-dependent (20-150 tokens/sec)
- **Cost**: Hardware cost only
- **Best for**: Edge deployment, Apple Silicon, CPU inference
### Memory Requirements
**Transformers (FP16):**
- 7B model: ~14GB GPU VRAM
- 13B model: ~26GB GPU VRAM
- 70B model: ~140GB GPU VRAM (multi-GPU)
**llama.cpp (Q4_K_M):**
- 7B model: ~4.5GB RAM
- 13B model: ~8GB RAM
- 70B model: ~40GB RAM
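These figures follow from a simple rule of thumb: resident size ≈ parameter count × bits per weight / 8, plus some overhead for activations and the KV cache. A rough sketch of that estimate (the 20% overhead factor is an assumption):

```python
def approx_model_memory_gb(n_params_billion: float, bits_per_weight: float,
                           overhead: float = 1.2) -> float:
    """Rough resident-memory estimate for a dense transformer."""
    return n_params_billion * bits_per_weight / 8 * overhead

print(f"7B FP16:   {approx_model_memory_gb(7, 16):.1f} GB")   # ~16.8 GB
print(f"7B Q4_K_M: {approx_model_memory_gb(7, 4.5):.1f} GB")  # ~4.7 GB
```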
**Optimization Tips:**
- Use quantized models (Q4_K_M) for lower memory
- Use GPU offloading for faster inference
- Use CPU inference for smaller models (<7B)
## Performance Tuning
### API Models (Anthropic, OpenAI)
#### Reduce Latency
```python
from guidance import models, gen
lm = models.Anthropic("claude-sonnet-4-5-20250929")
# Use lower max_tokens (faster response)
lm += gen(max_tokens=100) # Instead of 1000
# Use streaming (perceived latency reduction)
for chunk in lm.stream(gen(max_tokens=500)):
print(chunk, end="", flush=True)
```
#### Reduce Cost
```python
# Use cheaper models
lm = models.Anthropic("claude-3-5-haiku-20241022") # vs Sonnet
lm = models.OpenAI("gpt-4o-mini") # vs gpt-4o
# Reduce context size
# - Keep prompts concise
# - Avoid large few-shot examples
# - Use max_tokens limits
```
### Local Models (Transformers, llama.cpp)
#### Optimize GPU Usage
```python
from guidance.models import Transformers
# Use FP16 for 2x speedup
lm = Transformers(
"meta-llama/Llama-3.1-8B-Instruct",
device="cuda",
torch_dtype="float16"
)
# Use 8-bit quantization for 4x memory reduction
lm = Transformers(
"meta-llama/Llama-3.1-8B-Instruct",
device="cuda",
load_in_8bit=True
)
# Use flash attention (requires flash-attn package)
lm = Transformers(
"meta-llama/Llama-3.1-8B-Instruct",
device="cuda",
use_flash_attention_2=True
)
```
#### Optimize llama.cpp
```python
from guidance.models import LlamaCpp
# Maximize GPU layers
lm = LlamaCpp(
model_path="/path/to/model.Q4_K_M.gguf",
n_gpu_layers=-1 # All layers on GPU
)
# Optimize batch size
lm = LlamaCpp(
model_path="/path/to/model.Q4_K_M.gguf",
n_batch=512, # Larger batch = faster prompt processing
n_gpu_layers=-1
)
# Use Metal (Apple Silicon)
lm = LlamaCpp(
model_path="/path/to/model.Q4_K_M.gguf",
n_gpu_layers=-1, # Use Metal GPU acceleration
use_mmap=True
)
```
#### Batch Processing
```python
# Process multiple requests efficiently
from guidance import gen
from guidance.models import Transformers
requests = [
"What is 2+2?",
"What is the capital of France?",
"What is photosynthesis?"
]
# Bad: Reloads the model for every request
for req in requests:
lm = Transformers("microsoft/Phi-4-mini-instruct")
lm += req + gen(max_tokens=50)
# Good: Reuse loaded model
lm = Transformers("microsoft/Phi-4-mini-instruct")
for req in requests:
lm += req + gen(max_tokens=50)
```
## Advanced Configuration
### Custom Model Configurations
```python
from transformers import AutoTokenizer, AutoModelForCausalLM
from guidance.models import Transformers
# Load custom model
tokenizer = AutoTokenizer.from_pretrained("your-model")
model = AutoModelForCausalLM.from_pretrained(
"your-model",
device_map="auto",
torch_dtype="float16"
)
# Use with Guidance
lm = Transformers(model=model, tokenizer=tokenizer)
```
### Environment Variables
```bash
# API keys
export ANTHROPIC_API_KEY="sk-ant-..."
export OPENAI_API_KEY="sk-..."
# Transformers cache
export HF_HOME="/path/to/cache"
export TRANSFORMERS_CACHE="/path/to/cache"
# GPU selection
export CUDA_VISIBLE_DEVICES=0,1 # Use GPU 0 and 1
```
### Debugging
```python
# Enable verbose logging
import logging
import torch
from guidance import models
from guidance.models import Transformers
logging.basicConfig(level=logging.DEBUG)
# Check backend info
lm = models.Anthropic("claude-sonnet-4-5-20250929")
print(f"Model: {lm.model_name}")
print(f"Backend: {lm.backend}")
# Check GPU usage (Transformers)
lm = Transformers("microsoft/Phi-4-mini-instruct", device="cuda")
print(f"Device: {lm.device}")
print(f"Memory allocated: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
```
## Resources
- **Anthropic Docs**: https://docs.anthropic.com
- **OpenAI Docs**: https://platform.openai.com/docs
- **Hugging Face Models**: https://huggingface.co/models
- **llama.cpp**: https://github.com/ggerganov/llama.cpp
- **GGUF Models**: https://huggingface.co/models?library=gguf

View file

@ -0,0 +1,674 @@
# Comprehensive Constraint Patterns
Guide to regex constraints, grammar-based generation, and token healing in Guidance.
## Table of Contents
- Regex Constraints
- Grammar-Based Generation
- Token Healing
- Selection Constraints
- Complex Patterns
- Performance Optimization
## Regex Constraints
### Basic Patterns
#### Numeric Constraints
```python
from guidance import models, gen
lm = models.Anthropic("claude-sonnet-4-5-20250929")
# Integer (positive)
lm += "Age: " + gen("age", regex=r"[0-9]+")
# Integer (with negatives)
lm += "Temperature: " + gen("temp", regex=r"-?[0-9]+")
# Float (positive)
lm += "Price: $" + gen("price", regex=r"[0-9]+\.[0-9]{2}")
# Float (with negatives and optional decimals)
lm += "Value: " + gen("value", regex=r"-?[0-9]+(\.[0-9]+)?")
# Percentage (0-100)
lm += "Progress: " + gen("progress", regex=r"(100|[0-9]{1,2})")
# Range (1-5 stars)
lm += "Rating: " + gen("rating", regex=r"[1-5]") + " stars"
```
#### Text Constraints
```python
# Alphabetic only
lm += "Name: " + gen("name", regex=r"[A-Za-z]+")
# Alphabetic with spaces
lm += "Full Name: " + gen("full_name", regex=r"[A-Za-z ]+")
# Alphanumeric
lm += "Username: " + gen("username", regex=r"[A-Za-z0-9_]+")
# Capitalized words
lm += "Title: " + gen("title", regex=r"[A-Z][a-z]+( [A-Z][a-z]+)*")
# Lowercase only
lm += "Code: " + gen("code", regex=r"[a-z0-9-]+")
# Specific length
lm += "ID: " + gen("id", regex=r"[A-Z]{3}-[0-9]{6}") # e.g., "ABC-123456"
```
#### Date and Time Constraints
```python
# Date (YYYY-MM-DD)
lm += "Date: " + gen("date", regex=r"\d{4}-\d{2}-\d{2}")
# Date (MM/DD/YYYY)
lm += "Date: " + gen("date_us", regex=r"\d{2}/\d{2}/\d{4}")
# Time (HH:MM)
lm += "Time: " + gen("time", regex=r"\d{2}:\d{2}")
# Time (HH:MM:SS)
lm += "Time: " + gen("time_full", regex=r"\d{2}:\d{2}:\d{2}")
# ISO 8601 datetime
lm += "Timestamp: " + gen(
"timestamp",
regex=r"\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}Z"
)
# Year (YYYY)
lm += "Year: " + gen("year", regex=r"(19|20)\d{2}")
# Month name
lm += "Month: " + gen(
"month",
regex=r"(January|February|March|April|May|June|July|August|September|October|November|December)"
)
```
#### Contact Information
```python
# Email
lm += "Email: " + gen(
"email",
regex=r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}"
)
# Phone (US format)
lm += "Phone: " + gen("phone", regex=r"\d{3}-\d{3}-\d{4}")
# Phone (international format)
lm += "Phone: " + gen("phone_intl", regex=r"\+[0-9]{1,3}-[0-9]{1,14}")
# ZIP code (US)
lm += "ZIP: " + gen("zip", regex=r"\d{5}(-\d{4})?")
# Postal code (Canada)
lm += "Postal: " + gen("postal", regex=r"[A-Z]\d[A-Z] \d[A-Z]\d")
# URL
lm += "URL: " + gen(
"url",
regex=r"https?://[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}(/[a-zA-Z0-9._~:/?#\[\]@!$&'()*+,;=-]*)?"
)
```
### Advanced Patterns
#### JSON Field Constraints
```python
from guidance import models, gen
lm = models.Anthropic("claude-sonnet-4-5-20250929")
# String field with quotes
lm += '"name": ' + gen("name", regex=r'"[A-Za-z ]+"')
# Numeric field (no quotes)
lm += '"age": ' + gen("age", regex=r"[0-9]+")
# Boolean field
lm += '"active": ' + gen("active", regex=r"(true|false)")
# Null field
lm += '"optional": ' + gen("optional", regex=r"(null|[0-9]+)")
# Array of strings
lm += '"tags": [' + gen(
"tags",
regex=r'"[a-z]+"(, "[a-z]+")*'
) + ']'
# Complete JSON object
lm += """{
"name": """ + gen("name", regex=r'"[A-Za-z ]+"') + """,
"age": """ + gen("age", regex=r"[0-9]+") + """,
"email": """ + gen(
"email",
regex=r'"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}"'
) + """
}"""
```
#### Code Patterns
```python
# Python variable name
lm += "Variable: " + gen("var", regex=r"[a-z_][a-z0-9_]*")
# Python function name
lm += "Function: " + gen("func", regex=r"[a-z_][a-z0-9_]*")
# Hex color code
lm += "Color: #" + gen("color", regex=r"[0-9A-Fa-f]{6}")
# UUID
lm += "UUID: " + gen(
"uuid",
regex=r"[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}"
)
# Git commit hash (short)
lm += "Commit: " + gen("commit", regex=r"[0-9a-f]{7}")
# Semantic version
lm += "Version: " + gen("version", regex=r"[0-9]+\.[0-9]+\.[0-9]+")
# IP address (IPv4)
lm += "IP: " + gen(
"ip",
regex=r"((25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)"
)
```
#### Domain-Specific Patterns
```python
# Credit card number
lm += "Card: " + gen("card", regex=r"\d{4}-\d{4}-\d{4}-\d{4}")
# Social Security Number (US)
lm += "SSN: " + gen("ssn", regex=r"\d{3}-\d{2}-\d{4}")
# ISBN-13
lm += "ISBN: " + gen("isbn", regex=r"978-\d{1,5}-\d{1,7}-\d{1,7}-\d")
# License plate (US)
lm += "Plate: " + gen("plate", regex=r"[A-Z]{3}-\d{4}")
# Currency amount
lm += "Amount: $" + gen("amount", regex=r"[0-9]{1,3}(,[0-9]{3})*\.[0-9]{2}")
# Percentage with decimal
lm += "Rate: " + gen("rate", regex=r"[0-9]+\.[0-9]{1,2}%")
```
## Grammar-Based Generation
### JSON Grammar
```python
from guidance import models, gen, guidance
@guidance
def json_object(lm):
"""Generate valid JSON object."""
lm += "{\n"
# Name field (required)
lm += ' "name": ' + gen("name", regex=r'"[A-Za-z ]+"') + ",\n"
# Age field (required)
lm += ' "age": ' + gen("age", regex=r"[0-9]+") + ",\n"
# Email field (required)
lm += ' "email": ' + gen(
"email",
regex=r'"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}"'
) + ",\n"
# Active field (required, boolean)
lm += ' "active": ' + gen("active", regex=r"(true|false)") + "\n"
lm += "}"
return lm
lm = models.Anthropic("claude-sonnet-4-5-20250929")
lm = json_object(lm)
print(lm) # Valid JSON guaranteed
```
### Nested JSON Grammar
```python
@guidance
def nested_json(lm):
"""Generate nested JSON structure."""
lm += "{\n"
# User object
lm += ' "user": {\n'
lm += ' "name": ' + gen("name", regex=r'"[A-Za-z ]+"') + ",\n"
lm += ' "age": ' + gen("age", regex=r"[0-9]+") + "\n"
lm += " },\n"
# Address object
lm += ' "address": {\n'
lm += ' "street": ' + gen("street", regex=r'"[A-Za-z0-9 ]+"') + ",\n"
lm += ' "city": ' + gen("city", regex=r'"[A-Za-z ]+"') + ",\n"
lm += ' "zip": ' + gen("zip", regex=r'"\d{5}"') + "\n"
lm += " }\n"
lm += "}"
return lm
```
### Array Grammar
```python
@guidance
def json_array(lm, count=3):
"""Generate JSON array with fixed count."""
lm += "[\n"
for i in range(count):
lm += " {\n"
lm += ' "id": ' + gen(f"id_{i}", regex=r"[0-9]+") + ",\n"
lm += ' "name": ' + gen(f"name_{i}", regex=r'"[A-Za-z ]+"') + "\n"
lm += " }"
if i < count - 1:
lm += ","
lm += "\n"
lm += "]"
return lm
```
### XML Grammar
```python
@guidance
def xml_document(lm):
"""Generate valid XML document."""
lm += '<?xml version="1.0"?>\n'
lm += "<person>\n"
# Name element
lm += " <name>" + gen("name", regex=r"[A-Za-z ]+") + "</name>\n"
# Age element
lm += " <age>" + gen("age", regex=r"[0-9]+") + "</age>\n"
# Email element
lm += " <email>" + gen(
"email",
regex=r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}"
) + "</email>\n"
lm += "</person>"
return lm
```
### CSV Grammar
```python
@guidance
def csv_row(lm):
"""Generate CSV row."""
lm += gen("name", regex=r"[A-Za-z ]+") + ","
lm += gen("age", regex=r"[0-9]+") + ","
lm += gen("email", regex=r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}")
return lm
@guidance
def csv_document(lm, rows=5):
"""Generate complete CSV."""
# Header
lm += "Name,Age,Email\n"
# Rows
for i in range(rows):
lm = csv_row(lm)
if i < rows - 1:
lm += "\n"
return lm
```
## Token Healing
### How Token Healing Works
**Problem:** Tokenization creates unnatural boundaries.
```python
# Example without token healing
prompt = "The capital of France is "
# Tokenization: ["The", " capital", " of", " France", " is", " "]
# Model sees last token: " "
# First generated token might include leading space: " Paris"
# Result: "The capital of France is Paris" (double space)
```
**Solution:** Guidance backs up and regenerates the last token.
```python
from guidance import models, gen
lm = models.Anthropic("claude-sonnet-4-5-20250929")
# Token healing enabled by default
lm += "The capital of France is " + gen("capital", max_tokens=5)
# Process:
# 1. Back up to token before " is "
# 2. Regenerate " is" + "capital" together
# 3. Result: "The capital of France is Paris" (correct)
```
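You can reproduce the tokenization above with any BPE tokenizer; a quick check (using `gpt2` purely for illustration):
```python
# Illustrating the boundary problem with a real tokenizer.
from transformers import AutoTokenizer
tok = AutoTokenizer.from_pretrained("gpt2")
ids = tok.encode("The capital of France is ")
print([tok.decode([i]) for i in ids])
# ['The', ' capital', ' of', ' France', ' is', ' ']
# The dangling ' ' token is why naive continuation can produce a doubled space;
# token healing rolls it back and regenerates it together with the completion.
```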
### Token Healing Examples
#### Natural Continuations
```python
# Before token healing
lm += "The function name is get" + gen("rest")
# Might generate: "The function name is get User" (space before User)
# With token healing
lm += "The function name is get" + gen("rest")
# Generates: "The function name is getUser" (correct camelCase)
```
#### Code Generation
```python
# Function name completion
lm += "def calculate_" + gen("rest", stop="(")
# Token healing ensures smooth connection: "calculate_total"
# Variable name completion
lm += "my_" + gen("var_name", regex=r"[a-z_]+")
# Token healing ensures: "my_variable_name" (not "my_ variable_name")
```
#### Domain-Specific Terms
```python
# Medical terms
lm += "The patient has hyper" + gen("condition")
# Token healing helps: "hypertension" (not "hyper tension")
# Technical terms
lm += "Using micro" + gen("tech")
# Token healing helps: "microservices" (not "micro services")
```
### Disabling Token Healing
```python
# Disable token healing if needed (rare)
lm += gen("text", token_healing=False)
```
## Selection Constraints
### Basic Selection
```python
from guidance import models, select
lm = models.Anthropic("claude-sonnet-4-5-20250929")
# Simple selection
lm += "Status: " + select(["active", "inactive", "pending"], name="status")
# Boolean selection
lm += "Approved: " + select(["Yes", "No"], name="approved")
# Multiple choice
lm += "Answer: " + select(
["A) Paris", "B) London", "C) Berlin", "D) Madrid"],
name="answer"
)
```
### Conditional Selection
```python
from guidance import models, select, gen, guidance
@guidance
def conditional_fields(lm):
"""Generate fields conditionally based on type."""
lm += "Type: " + select(["person", "company"], name="type")
if lm["type"] == "person":
lm += "\nName: " + gen("name", regex=r"[A-Za-z ]+")
lm += "\nAge: " + gen("age", regex=r"[0-9]+")
else:
lm += "\nCompany Name: " + gen("company", regex=r"[A-Za-z ]+")
lm += "\nEmployees: " + gen("employees", regex=r"[0-9]+")
return lm
```
### Repeated Selection
```python
@guidance
def multiple_selections(lm):
"""Select multiple items."""
lm += "Select 3 colors:\n"
colors = ["red", "blue", "green", "yellow", "purple"]
for i in range(3):
lm += f"{i+1}. " + select(colors, name=f"color_{i}") + "\n"
return lm
```
## Complex Patterns
### Pattern 1: Structured Forms
```python
@guidance
def user_form(lm):
"""Generate structured user form."""
lm += "=== User Registration ===\n\n"
# Name (alphabetic only)
lm += "Full Name: " + gen("name", regex=r"[A-Za-z ]+", stop="\n") + "\n"
# Age (numeric)
lm += "Age: " + gen("age", regex=r"[0-9]+", max_tokens=3) + "\n"
# Email (validated format)
lm += "Email: " + gen(
"email",
regex=r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}",
stop="\n"
) + "\n"
# Phone (US format)
lm += "Phone: " + gen("phone", regex=r"\d{3}-\d{3}-\d{4}") + "\n"
# Account type (selection)
lm += "Account Type: " + select(
["Standard", "Premium", "Enterprise"],
name="account_type"
) + "\n"
# Active status (boolean)
lm += "Active: " + select(["Yes", "No"], name="active") + "\n"
return lm
```
### Pattern 2: Multi-Entity Extraction
```python
@guidance
def extract_entities(lm, text):
"""Extract multiple entities with constraints."""
lm += f"Text: {text}\n\n"
# Person name (alphabetic)
lm += "Person: " + gen("person", regex=r"[A-Za-z ]+", stop="\n") + "\n"
# Organization (alphanumeric with spaces)
lm += "Organization: " + gen(
"organization",
regex=r"[A-Za-z0-9 ]+",
stop="\n"
) + "\n"
# Date (YYYY-MM-DD format)
lm += "Date: " + gen("date", regex=r"\d{4}-\d{2}-\d{2}") + "\n"
# Location (alphabetic with spaces)
lm += "Location: " + gen("location", regex=r"[A-Za-z ]+", stop="\n") + "\n"
# Amount (currency)
lm += "Amount: $" + gen("amount", regex=r"[0-9,]+\.[0-9]{2}") + "\n"
return lm
```
### Pattern 3: Code Generation
```python
@guidance
def generate_python_function(lm):
"""Generate Python function with constraints."""
# Function name (valid Python identifier)
lm += "def " + gen("func_name", regex=r"[a-z_][a-z0-9_]*") + "("
# Parameter name
lm += gen("param", regex=r"[a-z_][a-z0-9_]*") + "):\n"
# Docstring
lm += ' """' + gen("docstring", stop='"""', max_tokens=50) + '"""\n'
# Function body (constrained to valid Python)
lm += " return " + gen("return_value", stop="\n") + "\n"
return lm
```
### Pattern 4: Hierarchical Data
```python
@guidance
def org_chart(lm):
"""Generate organizational chart."""
lm += "Company: " + gen("company", regex=r"[A-Za-z ]+") + "\n\n"
# CEO
lm += "CEO: " + gen("ceo", regex=r"[A-Za-z ]+") + "\n"
# Departments
for dept in ["Engineering", "Sales", "Marketing"]:
lm += f"\n{dept} Department:\n"
lm += " Head: " + gen(f"{dept.lower()}_head", regex=r"[A-Za-z ]+") + "\n"
lm += " Size: " + gen(f"{dept.lower()}_size", regex=r"[0-9]+") + " employees\n"
return lm
```
## Performance Optimization
### Best Practices
#### 1. Use Specific Patterns
```python
# ✅ Good: Specific pattern
lm += gen("age", regex=r"[0-9]{1,3}") # Fast
# ❌ Bad: Overly broad pattern
lm += gen("age", regex=r"[0-9]+") # Slower
```
#### 2. Limit Max Tokens
```python
# ✅ Good: Reasonable limit
lm += gen("name", max_tokens=30)
# ❌ Bad: No limit
lm += gen("name") # May generate forever
```
#### 3. Use stop Sequences
```python
# ✅ Good: Stop at newline
lm += gen("line", stop="\n")
# ❌ Bad: Rely on max_tokens
lm += gen("line", max_tokens=100)
```
#### 4. Cache Compiled Grammars
```python
# Grammars are cached automatically after first use
# No manual caching needed
@guidance
def reusable_pattern(lm):
"""This grammar is compiled once and cached."""
lm += gen("email", regex=r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}")
return lm
# First call: compiles grammar
lm = reusable_pattern(lm)
# Subsequent calls: uses cached grammar (fast)
lm = reusable_pattern(lm)
```
#### 5. Avoid Overlapping Constraints
```python
# ✅ Good: Clear constraints
lm += gen("age", regex=r"[0-9]+", max_tokens=3)
# ❌ Bad: Conflicting constraints
lm += gen("age", regex=r"[0-9]{2}", max_tokens=10) # max_tokens unnecessary
```
### Performance Benchmarks
**Regex vs Free Generation:**
- Simple regex (digits): ~1.2x slower than free gen
- Complex regex (email): ~1.5x slower than free gen
- Grammar-based: ~2x slower than free gen
**But:**
- 100% valid outputs (vs ~70% with free gen + validation)
- No retry loops needed
- Overall faster end-to-end for structured outputs
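A rough sanity check of that end-to-end claim, using the figures above (~70% validity for free generation, ~1.2x per-call cost for a simple regex constraint):
```python
# Expected cost of "generate freely, validate, retry" vs one constrained call.
p_valid = 0.70                  # fraction of free generations that pass validation
free_gen_cost = 1 / p_valid     # ~1.43x base cost (expected calls until success)
constrained_cost = 1.2          # one call at ~1.2x base cost (simple regex)
print(free_gen_cost, constrained_cost)
# Constrained generation is cheaper here and always valid; a ~2x grammar costs
# more per call but still removes retry loops and downstream validation code.
```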
**Optimization Tips:**
- Use regex for critical fields only
- Use `select()` for small fixed sets (fastest)
- Use `stop` sequences when possible (faster than max_tokens)
- Cache compiled grammars by reusing functions
## Resources
- **Token Healing Paper**: https://arxiv.org/abs/2306.17648
- **Guidance Docs**: https://guidance.readthedocs.io
- **GitHub**: https://github.com/guidance-ai/guidance

View file

@ -0,0 +1,767 @@
# Production-Ready Examples
Real-world examples of using Guidance for structured generation, agents, and workflows.
## Table of Contents
- JSON Generation
- Data Extraction
- Classification Systems
- Agent Systems
- Multi-Step Workflows
- Code Generation
- Production Tips
## JSON Generation
### Basic JSON
```python
from guidance import models, gen, guidance
@guidance
def generate_user(lm):
"""Generate valid user JSON."""
lm += "{\n"
lm += ' "name": ' + gen("name", regex=r'"[A-Za-z ]+"') + ",\n"
lm += ' "age": ' + gen("age", regex=r"[0-9]+") + ",\n"
lm += ' "email": ' + gen(
"email",
regex=r'"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}"'
) + "\n"
lm += "}"
return lm
# Use it
lm = models.Anthropic("claude-sonnet-4-5-20250929")
lm += "Generate a user profile:\n"
lm = generate_user(lm)
print(lm)
# Output: Valid JSON guaranteed
```
### Nested JSON
```python
@guidance
def generate_order(lm):
"""Generate nested order JSON."""
lm += "{\n"
# Customer info
lm += ' "customer": {\n'
lm += ' "name": ' + gen("customer_name", regex=r'"[A-Za-z ]+"') + ",\n"
lm += ' "email": ' + gen(
"customer_email",
regex=r'"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}"'
) + "\n"
lm += " },\n"
# Order details
lm += ' "order": {\n'
lm += ' "id": ' + gen("order_id", regex=r'"ORD-[0-9]{6}"') + ",\n"
lm += ' "date": ' + gen("order_date", regex=r'"\d{4}-\d{2}-\d{2}"') + ",\n"
lm += ' "total": ' + gen("order_total", regex=r"[0-9]+\.[0-9]{2}") + "\n"
lm += " },\n"
# Status
lm += ' "status": ' + gen(
"status",
regex=r'"(pending|processing|shipped|delivered)"'
) + "\n"
lm += "}"
return lm
lm = models.Anthropic("claude-sonnet-4-5-20250929")
lm = generate_order(lm)
```
### JSON Array
```python
@guidance
def generate_user_list(lm, count=3):
"""Generate JSON array of users."""
lm += "[\n"
for i in range(count):
lm += " {\n"
lm += ' "id": ' + gen(f"id_{i}", regex=r"[0-9]+") + ",\n"
lm += ' "name": ' + gen(f"name_{i}", regex=r'"[A-Za-z ]+"') + ",\n"
lm += ' "active": ' + gen(f"active_{i}", regex=r"(true|false)") + "\n"
lm += " }"
if i < count - 1:
lm += ","
lm += "\n"
lm += "]"
return lm
lm = models.Anthropic("claude-sonnet-4-5-20250929")
lm = generate_user_list(lm, count=5)
```
### Dynamic JSON Schema
```python
import json
from guidance import models, gen, guidance
@guidance
def json_from_schema(lm, schema):
"""Generate JSON matching a schema."""
lm += "{\n"
fields = list(schema["properties"].items())
for i, (field_name, field_schema) in enumerate(fields):
lm += f' "{field_name}": '
# Handle different types
if field_schema["type"] == "string":
if "pattern" in field_schema:
lm += gen(field_name, regex=f'"{field_schema["pattern"]}"')
else:
lm += gen(field_name, regex=r'"[^"]+"')
elif field_schema["type"] == "number":
lm += gen(field_name, regex=r"[0-9]+(\.[0-9]+)?")
elif field_schema["type"] == "integer":
lm += gen(field_name, regex=r"[0-9]+")
elif field_schema["type"] == "boolean":
lm += gen(field_name, regex=r"(true|false)")
if i < len(fields) - 1:
lm += ","
lm += "\n"
lm += "}"
return lm
# Define schema
schema = {
"type": "object",
"properties": {
"name": {"type": "string"},
"age": {"type": "integer"},
"score": {"type": "number"},
"active": {"type": "boolean"}
}
}
lm = models.Anthropic("claude-sonnet-4-5-20250929")
lm = json_from_schema(lm, schema)
```
## Data Extraction
### Extract from Text
```python
from guidance import models, gen, guidance, system, user, assistant
@guidance
def extract_person_info(lm, text):
"""Extract structured info from text."""
lm += f"Text: {text}\n\n"
with assistant():
lm += "Name: " + gen("name", regex=r"[A-Za-z ]+", stop="\n") + "\n"
lm += "Age: " + gen("age", regex=r"[0-9]+", max_tokens=3) + "\n"
lm += "Occupation: " + gen("occupation", regex=r"[A-Za-z ]+", stop="\n") + "\n"
lm += "Email: " + gen(
"email",
regex=r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}",
stop="\n"
) + "\n"
return lm
text = "John Smith is a 35-year-old software engineer. Contact: john@example.com"
lm = models.Anthropic("claude-sonnet-4-5-20250929")
with system():
lm += "You extract structured information from text."
with user():
lm = extract_person_info(lm, text)
print(f"Name: {lm['name']}")
print(f"Age: {lm['age']}")
print(f"Occupation: {lm['occupation']}")
print(f"Email: {lm['email']}")
```
### Multi-Entity Extraction
```python
@guidance
def extract_entities(lm, text):
"""Extract multiple entity types."""
lm += f"Analyze: {text}\n\n"
# Person entities
lm += "People:\n"
for i in range(3): # Up to 3 people
lm += f"- " + gen(f"person_{i}", regex=r"[A-Za-z ]+", stop="\n") + "\n"
# Organization entities
lm += "\nOrganizations:\n"
for i in range(2): # Up to 2 orgs
lm += f"- " + gen(f"org_{i}", regex=r"[A-Za-z0-9 ]+", stop="\n") + "\n"
# Dates
lm += "\nDates:\n"
for i in range(2): # Up to 2 dates
lm += f"- " + gen(f"date_{i}", regex=r"\d{4}-\d{2}-\d{2}", stop="\n") + "\n"
# Locations
lm += "\nLocations:\n"
for i in range(2): # Up to 2 locations
lm += f"- " + gen(f"location_{i}", regex=r"[A-Za-z ]+", stop="\n") + "\n"
return lm
text = """
Tim Cook and Satya Nadella met at Microsoft headquarters in Redmond on 2024-09-15
to discuss the collaboration between Apple and Microsoft. The meeting continued
in Cupertino on 2024-09-20.
"""
lm = models.Anthropic("claude-sonnet-4-5-20250929")
lm = extract_entities(lm, text)
```
### Batch Extraction
```python
@guidance
def batch_extract(lm, texts):
"""Extract from multiple texts."""
lm += "Batch Extraction Results:\n\n"
for i, text in enumerate(texts):
lm += f"=== Item {i+1} ===\n"
lm += f"Text: {text}\n"
lm += "Name: " + gen(f"name_{i}", regex=r"[A-Za-z ]+", stop="\n") + "\n"
lm += "Sentiment: " + gen(
f"sentiment_{i}",
regex=r"(positive|negative|neutral)",
stop="\n"
) + "\n\n"
return lm
texts = [
"Alice is happy with the product",
"Bob is disappointed with the service",
"Carol has no strong feelings either way"
]
lm = models.Anthropic("claude-sonnet-4-5-20250929")
lm = batch_extract(lm, texts)
```
## Classification Systems
### Sentiment Analysis
```python
from guidance import models, select, gen
lm = models.Anthropic("claude-sonnet-4-5-20250929")
text = "This product is absolutely amazing! Best purchase ever."
lm += f"Text: {text}\n\n"
lm += "Sentiment: " + select(
["positive", "negative", "neutral"],
name="sentiment"
)
lm += "\nConfidence: " + gen("confidence", regex=r"[0-9]{1,3}") + "%\n"
lm += "Reasoning: " + gen("reasoning", stop="\n", max_tokens=50)
print(f"Sentiment: {lm['sentiment']}")
print(f"Confidence: {lm['confidence']}%")
print(f"Reasoning: {lm['reasoning']}")
```
### Multi-Label Classification
```python
@guidance
def classify_article(lm, text):
"""Classify article with multiple labels."""
lm += f"Article: {text}\n\n"
# Primary category
lm += "Primary Category: " + select(
["Technology", "Business", "Science", "Politics", "Entertainment"],
name="primary_category"
) + "\n"
# Secondary categories (up to 3)
lm += "\nSecondary Categories:\n"
categories = ["Technology", "Business", "Science", "Politics", "Entertainment"]
for i in range(3):
lm += f"{i+1}. " + select(categories, name=f"secondary_{i}") + "\n"
# Tags
lm += "\nTags: " + gen("tags", stop="\n", max_tokens=50) + "\n"
# Target audience
lm += "Target Audience: " + select(
["General", "Expert", "Beginner"],
name="audience"
)
return lm
article = """
Apple announced new AI features in iOS 18, leveraging machine learning to improve
battery life and performance. The company's stock rose 5% following the announcement.
"""
lm = models.Anthropic("claude-sonnet-4-5-20250929")
lm = classify_article(lm, article)
```
### Intent Classification
```python
@guidance
def classify_intent(lm, message):
"""Classify user intent."""
lm += f"User Message: {message}\n\n"
# Intent
lm += "Intent: " + select(
["question", "complaint", "request", "feedback", "other"],
name="intent"
) + "\n"
# Urgency
lm += "Urgency: " + select(
["low", "medium", "high", "critical"],
name="urgency"
) + "\n"
# Department
lm += "Route To: " + select(
["support", "sales", "billing", "technical"],
name="department"
) + "\n"
# Sentiment
lm += "Sentiment: " + select(
["positive", "neutral", "negative"],
name="sentiment"
)
return lm
message = "My account was charged twice for the same order. Need help ASAP!"
lm = models.Anthropic("claude-sonnet-4-5-20250929")
lm = classify_intent(lm, message)
print(f"Intent: {lm['intent']}")
print(f"Urgency: {lm['urgency']}")
print(f"Department: {lm['department']}")
```
## Agent Systems
### ReAct Agent
```python
from guidance import models, gen, select, guidance
@guidance(stateless=False)
def react_agent(lm, question, tools, max_rounds=5):
"""ReAct agent with tool use."""
lm += f"Question: {question}\n\n"
for round in range(max_rounds):
# Thought
lm += f"Thought {round+1}: " + gen("thought", stop="\n", max_tokens=100) + "\n"
# Action selection
lm += "Action: " + select(
list(tools.keys()) + ["answer"],
name="action"
)
if lm["action"] == "answer":
lm += "\n\nFinal Answer: " + gen("answer", max_tokens=200)
break
# Action input
lm += "\nAction Input: " + gen("action_input", stop="\n", max_tokens=100) + "\n"
# Execute tool
if lm["action"] in tools:
try:
result = tools[lm["action"]](lm["action_input"])
lm += f"Observation: {result}\n\n"
except Exception as e:
lm += f"Observation: Error - {str(e)}\n\n"
return lm
# Define tools
tools = {
"calculator": lambda expr: eval(expr),
"search": lambda query: f"Search results for '{query}': [Mock results]",
"weather": lambda city: f"Weather in {city}: Sunny, 72°F"
}
# Use agent
lm = models.Anthropic("claude-sonnet-4-5-20250929")
lm = react_agent(lm, "What is (25 * 4) + 10?", tools)
print(lm["answer"])
```
### Multi-Agent System
```python
@guidance
def coordinator_agent(lm, task):
"""Coordinator that delegates to specialists."""
lm += f"Task: {task}\n\n"
# Determine which specialist to use
lm += "Specialist: " + select(
["researcher", "writer", "coder", "analyst"],
name="specialist"
) + "\n"
lm += "Reasoning: " + gen("reasoning", stop="\n", max_tokens=100) + "\n"
return lm
@guidance
def researcher_agent(lm, query):
"""Research specialist."""
lm += f"Research Query: {query}\n\n"
lm += "Findings:\n"
for i in range(3):
lm += f"{i+1}. " + gen(f"finding_{i}", stop="\n", max_tokens=100) + "\n"
return lm
@guidance
def writer_agent(lm, topic):
"""Writing specialist."""
lm += f"Topic: {topic}\n\n"
lm += "Title: " + gen("title", stop="\n", max_tokens=50) + "\n"
lm += "Content:\n" + gen("content", max_tokens=500)
return lm
# Coordination workflow
task = "Write an article about AI safety"
lm = models.Anthropic("claude-sonnet-4-5-20250929")
lm = coordinator_agent(lm, task)
specialist = lm["specialist"]
if specialist == "researcher":
lm = researcher_agent(lm, task)
elif specialist == "writer":
lm = writer_agent(lm, task)
```
### Tool Use with Validation
```python
@guidance(stateless=False)
def validated_tool_agent(lm, question):
"""Agent with validated tool calls."""
tools = {
"add": lambda a, b: float(a) + float(b),
"multiply": lambda a, b: float(a) * float(b),
"divide": lambda a, b: float(a) / float(b) if float(b) != 0 else "Error: Division by zero"
}
lm += f"Question: {question}\n\n"
for i in range(5):
# Select tool
lm += "Tool: " + select(list(tools.keys()) + ["done"], name="tool")
if lm["tool"] == "done":
lm += "\nAnswer: " + gen("answer", max_tokens=100)
break
# Get validated numeric arguments
lm += "\nArg1: " + gen("arg1", regex=r"-?[0-9]+(\.[0-9]+)?") + "\n"
lm += "Arg2: " + gen("arg2", regex=r"-?[0-9]+(\.[0-9]+)?") + "\n"
# Execute
result = tools[lm["tool"]](lm["arg1"], lm["arg2"])
lm += f"Result: {result}\n\n"
return lm
lm = models.Anthropic("claude-sonnet-4-5-20250929")
lm = validated_tool_agent(lm, "What is (10 + 5) * 3?")
```
## Multi-Step Workflows
### Chain of Thought
```python
@guidance
def chain_of_thought(lm, question):
"""Multi-step reasoning with CoT."""
lm += f"Question: {question}\n\n"
# Generate reasoning steps
lm += "Let me think step by step:\n\n"
for i in range(4):
lm += f"Step {i+1}: " + gen(f"step_{i+1}", stop="\n", max_tokens=100) + "\n"
# Final answer
lm += "\nTherefore, the answer is: " + gen("answer", stop="\n", max_tokens=50)
return lm
lm = models.Anthropic("claude-sonnet-4-5-20250929")
lm = chain_of_thought(lm, "If a train travels 60 mph for 2.5 hours, how far does it go?")
print(lm["answer"])
```
### Self-Consistency
```python
@guidance
def self_consistency(lm, question, num_samples=3):
"""Generate multiple reasoning paths and aggregate."""
lm += f"Question: {question}\n\n"
answers = []
for i in range(num_samples):
lm += f"=== Attempt {i+1} ===\n"
lm += "Reasoning: " + gen(f"reasoning_{i}", stop="\n", max_tokens=100) + "\n"
lm += "Answer: " + gen(f"answer_{i}", stop="\n", max_tokens=50) + "\n\n"
answers.append(lm[f"answer_{i}"])
# Aggregate (simple majority vote)
from collections import Counter
most_common = Counter(answers).most_common(1)[0][0]
lm += f"Final Answer (by majority): {most_common}\n"
return lm
lm = models.Anthropic("claude-sonnet-4-5-20250929")
lm = self_consistency(lm, "What is 15% of 200?")
```
### Planning and Execution
```python
@guidance
def plan_and_execute(lm, goal):
"""Plan tasks then execute them."""
lm += f"Goal: {goal}\n\n"
# Planning phase
lm += "Plan:\n"
num_steps = 4
for i in range(num_steps):
lm += f"{i+1}. " + gen(f"plan_step_{i}", stop="\n", max_tokens=100) + "\n"
# Execution phase
lm += "\nExecution:\n\n"
for i in range(num_steps):
lm += f"Step {i+1}: {lm[f'plan_step_{i}']}\n"
lm += "Status: " + select(["completed", "in-progress", "blocked"], name=f"status_{i}") + "\n"
lm += "Result: " + gen(f"result_{i}", stop="\n", max_tokens=150) + "\n\n"
# Summary
lm += "Summary: " + gen("summary", max_tokens=200)
return lm
lm = models.Anthropic("claude-sonnet-4-5-20250929")
lm = plan_and_execute(lm, "Build a REST API for a blog platform")
```
## Code Generation
### Python Function
```python
@guidance
def generate_python_function(lm, description):
"""Generate Python function from description."""
lm += f"Description: {description}\n\n"
# Function signature
lm += "def " + gen("func_name", regex=r"[a-z_][a-z0-9_]*") + "("
lm += gen("params", regex=r"[a-z_][a-z0-9_]*(, [a-z_][a-z0-9_]*)*") + "):\n"
# Docstring
lm += ' """' + gen("docstring", stop='"""', max_tokens=100) + '"""\n'
# Function body
lm += " " + gen("body", stop="\n", max_tokens=200) + "\n"
return lm
lm = models.Anthropic("claude-sonnet-4-5-20250929")
lm = generate_python_function(lm, "Check if a number is prime")
print(lm)
```
### SQL Query
```python
@guidance
def generate_sql(lm, description):
"""Generate SQL query from description."""
lm += f"Description: {description}\n\n"
lm += "SQL Query:\n"
# SELECT clause
lm += "SELECT " + gen("select_clause", stop=" FROM", max_tokens=100)
# FROM clause
lm += " FROM " + gen("from_clause", stop=" WHERE", max_tokens=50)
# WHERE clause (optional)
lm += " WHERE " + gen("where_clause", stop=";", max_tokens=100) + ";"
return lm
lm = models.Anthropic("claude-sonnet-4-5-20250929")
lm = generate_sql(lm, "Get all users who signed up in the last 30 days")
```
### API Endpoint
```python
@guidance
def generate_api_endpoint(lm, description):
"""Generate REST API endpoint."""
lm += f"Description: {description}\n\n"
# HTTP method
lm += "Method: " + select(["GET", "POST", "PUT", "DELETE"], name="method") + "\n"
# Path
lm += "Path: /" + gen("path", regex=r"[a-z0-9/-]+", stop="\n") + "\n"
# Request body (if POST/PUT)
if lm["method"] in ["POST", "PUT"]:
lm += "\nRequest Body:\n"
lm += "{\n"
lm += ' "field1": ' + gen("field1", regex=r'"[a-z_]+"') + ",\n"
lm += ' "field2": ' + gen("field2", regex=r'"[a-z_]+"') + "\n"
lm += "}\n"
# Response
lm += "\nResponse (200 OK):\n"
lm += "{\n"
lm += ' "status": "success",\n'
lm += ' "data": ' + gen("response_data", max_tokens=100) + "\n"
lm += "}\n"
return lm
lm = models.Anthropic("claude-sonnet-4-5-20250929")
lm = generate_api_endpoint(lm, "Create a new blog post")
```
## Production Tips
### Error Handling
```python
@guidance
def safe_extraction(lm, text):
"""Extract with fallback handling."""
try:
lm += f"Text: {text}\n"
lm += "Name: " + gen("name", regex=r"[A-Za-z ]+", stop="\n", max_tokens=30)
return lm
except Exception as e:
# Fallback to less strict extraction
lm += f"Text: {text}\n"
lm += "Name: " + gen("name", stop="\n", max_tokens=30)
return lm
```
### Caching
```python
from functools import lru_cache
@lru_cache(maxsize=100)
def cached_generation(text):
"""Cache LLM generations."""
lm = models.Anthropic("claude-sonnet-4-5-20250929")
lm += f"Analyze: {text}\n"
lm += "Sentiment: " + select(["positive", "negative", "neutral"], name="sentiment")
return lm["sentiment"]
# First call: hits LLM
result1 = cached_generation("This is great!")
# Second call: returns cached result
result2 = cached_generation("This is great!") # Instant!
```
### Monitoring
```python
import time
@guidance
def monitored_generation(lm, text):
"""Track generation metrics."""
start_time = time.time()
lm += f"Text: {text}\n"
lm += "Analysis: " + gen("analysis", max_tokens=100)
elapsed = time.time() - start_time
# Log metrics
print(f"Generation time: {elapsed:.2f}s")
print(f"Output length: {len(lm['analysis'])} chars")
return lm
```
### Batch Processing
```python
def batch_process(texts, batch_size=10):
    """Process texts in batches."""
    lm = models.Anthropic("claude-sonnet-4-5-20250929")
    results = []
    for start in range(0, len(texts), batch_size):
        batch = texts[start:start + batch_size]
        for offset, text in enumerate(batch):
            idx = start + offset  # unique capture name per text
            lm += f"Text: {text}\n"
            lm += "Sentiment: " + select(
                ["positive", "negative", "neutral"],
                name=f"sentiment_{idx}"
            ) + "\n\n"
        results.extend(lm[f"sentiment_{start + offset}"] for offset in range(len(batch)))
    return results
```
## Resources
- **Guidance Notebooks**: https://github.com/guidance-ai/guidance/tree/main/notebooks
- **Guidance Docs**: https://guidance.readthedocs.io
- **Community Examples**: https://github.com/guidance-ai/guidance/discussions

View file

@ -0,0 +1,743 @@
---
name: instructor
description: Extract structured data from LLM responses with Pydantic validation, retry failed extractions automatically, parse complex JSON with type safety, and stream partial results with Instructor - battle-tested structured output library
version: 1.0.0
author: Orchestra Research
license: MIT
dependencies: [instructor, pydantic, openai, anthropic]
metadata:
hermes:
tags: [Prompt Engineering, Instructor, Structured Output, Pydantic, Data Extraction, JSON Parsing, Type Safety, Validation, Streaming, OpenAI, Anthropic]
---
# Instructor: Structured LLM Outputs
## When to Use This Skill
Use Instructor when you need to:
- **Extract structured data** from LLM responses reliably
- **Validate outputs** against Pydantic schemas automatically
- **Retry failed extractions** with automatic error handling
- **Parse complex JSON** with type safety and validation
- **Stream partial results** for real-time processing
- **Support multiple LLM providers** with consistent API
**GitHub Stars**: 15,000+ | **Battle-tested**: 100,000+ developers
## Installation
```bash
# Base installation
pip install instructor
# With specific providers
pip install "instructor[anthropic]" # Anthropic Claude
pip install "instructor[openai]" # OpenAI
pip install "instructor[all]" # All providers
```
## Quick Start
### Basic Example: Extract User Data
```python
import instructor
from pydantic import BaseModel
from anthropic import Anthropic
# Define output structure
class User(BaseModel):
name: str
age: int
email: str
# Create instructor client
client = instructor.from_anthropic(Anthropic())
# Extract structured data
user = client.messages.create(
model="claude-sonnet-4-5-20250929",
max_tokens=1024,
messages=[{
"role": "user",
"content": "John Doe is 30 years old. His email is john@example.com"
}],
response_model=User
)
print(user.name) # "John Doe"
print(user.age) # 30
print(user.email) # "john@example.com"
```
### With OpenAI
```python
from openai import OpenAI
client = instructor.from_openai(OpenAI())
user = client.chat.completions.create(
model="gpt-4o-mini",
response_model=User,
messages=[{"role": "user", "content": "Extract: Alice, 25, alice@email.com"}]
)
```
## Core Concepts
### 1. Response Models (Pydantic)
Response models define the structure and validation rules for LLM outputs.
#### Basic Model
```python
from pydantic import BaseModel, Field
class Article(BaseModel):
title: str = Field(description="Article title")
author: str = Field(description="Author name")
word_count: int = Field(description="Number of words", gt=0)
tags: list[str] = Field(description="List of relevant tags")
article = client.messages.create(
model="claude-sonnet-4-5-20250929",
max_tokens=1024,
messages=[{
"role": "user",
"content": "Analyze this article: [article text]"
}],
response_model=Article
)
```
**Benefits:**
- Type safety with Python type hints
- Automatic validation (word_count > 0)
- Self-documenting with Field descriptions
- IDE autocomplete support
#### Nested Models
```python
class Address(BaseModel):
street: str
city: str
country: str
class Person(BaseModel):
name: str
age: int
address: Address # Nested model
person = client.messages.create(
model="claude-sonnet-4-5-20250929",
max_tokens=1024,
messages=[{
"role": "user",
"content": "John lives at 123 Main St, Boston, USA"
}],
response_model=Person
)
print(person.address.city) # "Boston"
```
#### Optional Fields
```python
from typing import Optional
class Product(BaseModel):
name: str
price: float
discount: Optional[float] = None # Optional
description: str = Field(default="No description") # Default value
# LLM doesn't need to provide discount or description
```
#### Enums for Constraints
```python
from enum import Enum
class Sentiment(str, Enum):
POSITIVE = "positive"
NEGATIVE = "negative"
NEUTRAL = "neutral"
class Review(BaseModel):
text: str
sentiment: Sentiment # Only these 3 values allowed
review = client.messages.create(
model="claude-sonnet-4-5-20250929",
max_tokens=1024,
messages=[{
"role": "user",
"content": "This product is amazing!"
}],
response_model=Review
)
print(review.sentiment) # Sentiment.POSITIVE
```
### 2. Validation
Pydantic validates LLM outputs automatically. If validation fails, Instructor retries.
#### Built-in Validators
```python
from pydantic import Field, EmailStr, HttpUrl
class Contact(BaseModel):
name: str = Field(min_length=2, max_length=100)
age: int = Field(ge=0, le=120) # 0 <= age <= 120
email: EmailStr # Validates email format
website: HttpUrl # Validates URL format
# If LLM provides invalid data, Instructor retries automatically
```
#### Custom Validators
```python
from pydantic import field_validator
class Event(BaseModel):
name: str
date: str
attendees: int
@field_validator('date')
def validate_date(cls, v):
"""Ensure date is in YYYY-MM-DD format."""
import re
if not re.match(r'\d{4}-\d{2}-\d{2}', v):
raise ValueError('Date must be YYYY-MM-DD format')
return v
@field_validator('attendees')
def validate_attendees(cls, v):
"""Ensure positive attendees."""
if v < 1:
raise ValueError('Must have at least 1 attendee')
return v
```
#### Model-Level Validation
```python
from pydantic import model_validator
class DateRange(BaseModel):
start_date: str
end_date: str
@model_validator(mode='after')
def check_dates(self):
"""Ensure end_date is after start_date."""
from datetime import datetime
start = datetime.strptime(self.start_date, '%Y-%m-%d')
end = datetime.strptime(self.end_date, '%Y-%m-%d')
if end < start:
raise ValueError('end_date must be after start_date')
return self
```
### 3. Automatic Retrying
Instructor retries automatically when validation fails, providing error feedback to the LLM.
```python
# Retries up to 3 times if validation fails
user = client.messages.create(
model="claude-sonnet-4-5-20250929",
max_tokens=1024,
messages=[{
"role": "user",
"content": "Extract user from: John, age unknown"
}],
response_model=User,
max_retries=3 # Default is 3
)
# If age can't be extracted, Instructor tells the LLM:
# "Validation error: age - field required"
# LLM tries again with better extraction
```
**How it works:**
1. LLM generates output
2. Pydantic validates
3. If invalid: Error message sent back to LLM
4. LLM tries again with error feedback
5. Repeats up to max_retries
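Conceptually, the loop looks something like this (a simplified sketch, not Instructor's actual implementation; `call_llm` stands in for the provider call):
```python
# Simplified sketch of the retry loop (illustrative only).
from pydantic import ValidationError
def extract_with_retries(call_llm, response_model, messages, max_retries=3):
    for _ in range(max_retries):
        raw = call_llm(messages)                            # 1-2. generate raw output
        try:
            return response_model.model_validate_json(raw)  # validate with Pydantic
        except ValidationError as err:
            # send the validation errors back so the model can self-correct
            messages = messages + [
                {"role": "assistant", "content": raw},
                {"role": "user", "content": f"Fix these validation errors: {err}"},
            ]
    raise RuntimeError("Extraction still invalid after max_retries attempts")
```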
### 4. Streaming
Stream partial results for real-time processing.
#### Streaming Partial Objects
```python
from instructor import Partial
class Story(BaseModel):
title: str
content: str
tags: list[str]
# Stream partial updates as LLM generates
for partial_story in client.messages.create_partial(
model="claude-sonnet-4-5-20250929",
max_tokens=1024,
messages=[{
"role": "user",
"content": "Write a short sci-fi story"
}],
response_model=Story
):
print(f"Title: {partial_story.title}")
print(f"Content so far: {partial_story.content[:100]}...")
# Update UI in real-time
```
#### Streaming Iterables
```python
class Task(BaseModel):
title: str
priority: str
# Stream list items as they're generated
tasks = client.messages.create_iterable(
model="claude-sonnet-4-5-20250929",
max_tokens=1024,
messages=[{
"role": "user",
"content": "Generate 10 project tasks"
}],
response_model=Task
)
for task in tasks:
print(f"- {task.title} ({task.priority})")
# Process each task as it arrives
```
## Provider Configuration
### Anthropic Claude
```python
import instructor
from anthropic import Anthropic
client = instructor.from_anthropic(
Anthropic(api_key="your-api-key")
)
# Use with Claude models
response = client.messages.create(
model="claude-sonnet-4-5-20250929",
max_tokens=1024,
messages=[...],
response_model=YourModel
)
```
### OpenAI
```python
from openai import OpenAI
client = instructor.from_openai(
OpenAI(api_key="your-api-key")
)
response = client.chat.completions.create(
model="gpt-4o-mini",
response_model=YourModel,
messages=[...]
)
```
### Local Models (Ollama)
```python
from openai import OpenAI
# Point to local Ollama server
client = instructor.from_openai(
OpenAI(
base_url="http://localhost:11434/v1",
api_key="ollama" # Required but ignored
),
mode=instructor.Mode.JSON
)
response = client.chat.completions.create(
model="llama3.1",
response_model=YourModel,
messages=[...]
)
```
## Common Patterns
### Pattern 1: Data Extraction from Text
```python
class CompanyInfo(BaseModel):
name: str
founded_year: int
industry: str
employees: int
headquarters: str
text = """
Tesla, Inc. was founded in 2003. It operates in the automotive and energy
industry with approximately 140,000 employees. The company is headquartered
in Austin, Texas.
"""
company = client.messages.create(
model="claude-sonnet-4-5-20250929",
max_tokens=1024,
messages=[{
"role": "user",
"content": f"Extract company information from: {text}"
}],
response_model=CompanyInfo
)
```
### Pattern 2: Classification
```python
class Category(str, Enum):
TECHNOLOGY = "technology"
FINANCE = "finance"
HEALTHCARE = "healthcare"
EDUCATION = "education"
OTHER = "other"
class ArticleClassification(BaseModel):
category: Category
confidence: float = Field(ge=0.0, le=1.0)
keywords: list[str]
classification = client.messages.create(
model="claude-sonnet-4-5-20250929",
max_tokens=1024,
messages=[{
"role": "user",
"content": "Classify this article: [article text]"
}],
response_model=ArticleClassification
)
```
### Pattern 3: Multi-Entity Extraction
```python
class Person(BaseModel):
name: str
role: str
class Organization(BaseModel):
name: str
industry: str
class Entities(BaseModel):
people: list[Person]
organizations: list[Organization]
locations: list[str]
text = "Tim Cook, CEO of Apple, announced at the event in Cupertino..."
entities = client.messages.create(
model="claude-sonnet-4-5-20250929",
max_tokens=1024,
messages=[{
"role": "user",
"content": f"Extract all entities from: {text}"
}],
response_model=Entities
)
for person in entities.people:
print(f"{person.name} - {person.role}")
```
### Pattern 4: Structured Analysis
```python
class SentimentAnalysis(BaseModel):
overall_sentiment: Sentiment
positive_aspects: list[str]
negative_aspects: list[str]
suggestions: list[str]
score: float = Field(ge=-1.0, le=1.0)
review = "The product works well but setup was confusing..."
analysis = client.messages.create(
model="claude-sonnet-4-5-20250929",
max_tokens=1024,
messages=[{
"role": "user",
"content": f"Analyze this review: {review}"
}],
response_model=SentimentAnalysis
)
```
### Pattern 5: Batch Processing
```python
def extract_person(text: str) -> Person:
return client.messages.create(
model="claude-sonnet-4-5-20250929",
max_tokens=1024,
messages=[{
"role": "user",
"content": f"Extract person from: {text}"
}],
response_model=Person
)
texts = [
"John Doe is a 30-year-old engineer",
"Jane Smith, 25, works in marketing",
"Bob Johnson, age 40, software developer"
]
people = [extract_person(text) for text in texts]
```
## Advanced Features
### Union Types
```python
from typing import Union
class TextContent(BaseModel):
type: str = "text"
content: str
class ImageContent(BaseModel):
type: str = "image"
url: HttpUrl
caption: str
class Post(BaseModel):
title: str
content: Union[TextContent, ImageContent] # Either type
# LLM chooses appropriate type based on content
```
### Dynamic Models
```python
from pydantic import create_model
# Create model at runtime
DynamicUser = create_model(
'User',
name=(str, ...),
age=(int, Field(ge=0)),
email=(EmailStr, ...)
)
user = client.messages.create(
model="claude-sonnet-4-5-20250929",
max_tokens=1024,
messages=[...],
response_model=DynamicUser
)
```
### Custom Modes
```python
# For providers without native structured outputs
client = instructor.from_anthropic(
Anthropic(),
mode=instructor.Mode.JSON # JSON mode
)
# Available modes:
# - Mode.ANTHROPIC_TOOLS (recommended for Claude)
# - Mode.JSON (fallback)
# - Mode.TOOLS (OpenAI tools)
```
### Context Management
```python
# Single-use client
with instructor.from_anthropic(Anthropic()) as client:
result = client.messages.create(
model="claude-sonnet-4-5-20250929",
max_tokens=1024,
messages=[...],
response_model=YourModel
)
# Client closed automatically
```
## Error Handling
### Handling Validation Errors
```python
from pydantic import ValidationError
try:
user = client.messages.create(
model="claude-sonnet-4-5-20250929",
max_tokens=1024,
messages=[...],
response_model=User,
max_retries=3
)
except ValidationError as e:
print(f"Failed after retries: {e}")
# Handle gracefully
except Exception as e:
print(f"API error: {e}")
```
### Custom Error Messages
```python
class ValidatedUser(BaseModel):
name: str = Field(description="Full name, 2-100 characters")
age: int = Field(description="Age between 0 and 120", ge=0, le=120)
email: EmailStr = Field(description="Valid email address")
class Config:
# Custom error messages
json_schema_extra = {
"examples": [
{
"name": "John Doe",
"age": 30,
"email": "john@example.com"
}
]
}
```
## Best Practices
### 1. Clear Field Descriptions
```python
# ❌ Bad: Vague
class Product(BaseModel):
name: str
price: float
# ✅ Good: Descriptive
class Product(BaseModel):
name: str = Field(description="Product name from the text")
price: float = Field(description="Price in USD, without currency symbol")
```
### 2. Use Appropriate Validation
```python
# ✅ Good: Constrain values
class Rating(BaseModel):
score: int = Field(ge=1, le=5, description="Rating from 1 to 5 stars")
review: str = Field(min_length=10, description="Review text, at least 10 chars")
```
### 3. Provide Examples in Prompts
```python
messages = [{
"role": "user",
"content": """Extract person info from: "John, 30, engineer"
Example format:
{
"name": "John Doe",
"age": 30,
"occupation": "engineer"
}"""
}]
```
### 4. Use Enums for Fixed Categories
```python
# ✅ Good: Enum ensures valid values
class Status(str, Enum):
PENDING = "pending"
APPROVED = "approved"
REJECTED = "rejected"
class Application(BaseModel):
status: Status # LLM must choose from enum
```
### 5. Handle Missing Data Gracefully
```python
class PartialData(BaseModel):
required_field: str
optional_field: Optional[str] = None
default_field: str = "default_value"
# LLM only needs to provide required_field
```
## Comparison to Alternatives
| Feature | Instructor | Manual JSON | LangChain | DSPy |
|---------|------------|-------------|-----------|------|
| Type Safety | ✅ Yes | ❌ No | ⚠️ Partial | ✅ Yes |
| Auto Validation | ✅ Yes | ❌ No | ❌ No | ⚠️ Limited |
| Auto Retry | ✅ Yes | ❌ No | ❌ No | ✅ Yes |
| Streaming | ✅ Yes | ❌ No | ✅ Yes | ❌ No |
| Multi-Provider | ✅ Yes | ⚠️ Manual | ✅ Yes | ✅ Yes |
| Learning Curve | Low | Low | Medium | High |
**When to choose Instructor:**
- Need structured, validated outputs
- Want type safety and IDE support
- Require automatic retries
- Building data extraction systems
**When to choose alternatives:**
- DSPy: Need prompt optimization
- LangChain: Building complex chains
- Manual: Simple, one-off extractions
## Resources
- **Documentation**: https://python.useinstructor.com
- **GitHub**: https://github.com/jxnl/instructor (15k+ stars)
- **Cookbook**: https://python.useinstructor.com/examples
- **Discord**: Community support available
## See Also
- `references/validation.md` - Advanced validation patterns
- `references/providers.md` - Provider-specific configuration
- `references/examples.md` - Real-world use cases

View file

@ -0,0 +1,107 @@
# Real-World Examples
Practical examples of using Instructor for structured data extraction.
## Data Extraction
```python
import instructor
from enum import Enum
from anthropic import Anthropic
from pydantic import BaseModel, Field
# Shared setup reused by the snippets below
client = instructor.from_anthropic(Anthropic())
class CompanyInfo(BaseModel):
    name: str
    founded: int
    industry: str
    employees: int
text = "Apple was founded in 1976 in the technology industry with 164,000 employees."
company = client.messages.create(
model="claude-sonnet-4-5-20250929",
max_tokens=1024,
messages=[{"role": "user", "content": f"Extract: {text}"}],
response_model=CompanyInfo
)
```
## Classification
```python
class Sentiment(str, Enum):
POSITIVE = "positive"
NEGATIVE = "negative"
NEUTRAL = "neutral"
class Review(BaseModel):
sentiment: Sentiment
confidence: float = Field(ge=0.0, le=1.0)
review = client.messages.create(
model="claude-sonnet-4-5-20250929",
max_tokens=1024,
messages=[{"role": "user", "content": "This product is amazing!"}],
response_model=Review
)
```
## Multi-Entity Extraction
```python
class Person(BaseModel):
name: str
role: str
class Entities(BaseModel):
people: list[Person]
organizations: list[str]
locations: list[str]
entities = client.messages.create(
model="claude-sonnet-4-5-20250929",
max_tokens=1024,
messages=[{"role": "user", "content": "Tim Cook, CEO of Apple, spoke in Cupertino..."}],
response_model=Entities
)
```
## Structured Analysis
```python
class Analysis(BaseModel):
summary: str
key_points: list[str]
sentiment: Sentiment
actionable_items: list[str]
analysis = client.messages.create(
model="claude-sonnet-4-5-20250929",
max_tokens=1024,
messages=[{"role": "user", "content": "Analyze: [long text]"}],
response_model=Analysis
)
```
## Batch Processing
```python
texts = ["text1", "text2", "text3"]
results = [
client.messages.create(
model="claude-sonnet-4-5-20250929",
max_tokens=1024,
messages=[{"role": "user", "content": text}],
response_model=YourModel
)
for text in texts
]
```
## Streaming
```python
for partial in client.messages.create_partial(
model="claude-sonnet-4-5-20250929",
max_tokens=1024,
messages=[{"role": "user", "content": "Generate report..."}],
response_model=Report
):
print(f"Progress: {partial.title}")
# Update UI in real-time
```

View file

@ -0,0 +1,70 @@
# Provider Configuration
Guide to using Instructor with different LLM providers.
## Anthropic Claude
```python
import instructor
from anthropic import Anthropic
# Basic setup
client = instructor.from_anthropic(Anthropic())
# With API key
client = instructor.from_anthropic(
Anthropic(api_key="your-api-key")
)
# Recommended mode
client = instructor.from_anthropic(
Anthropic(),
mode=instructor.Mode.ANTHROPIC_TOOLS
)
# Usage
result = client.messages.create(
model="claude-sonnet-4-5-20250929",
max_tokens=1024,
messages=[{"role": "user", "content": "..."}],
response_model=YourModel
)
```
## OpenAI
```python
from openai import OpenAI
client = instructor.from_openai(OpenAI())
result = client.chat.completions.create(
model="gpt-4o-mini",
response_model=YourModel,
messages=[{"role": "user", "content": "..."}]
)
```
## Local Models (Ollama)
```python
client = instructor.from_openai(
OpenAI(
base_url="http://localhost:11434/v1",
api_key="ollama"
),
mode=instructor.Mode.JSON
)
result = client.chat.completions.create(
model="llama3.1",
response_model=YourModel,
messages=[...]
)
```
## Modes
- `Mode.ANTHROPIC_TOOLS`: Recommended for Claude
- `Mode.TOOLS`: OpenAI function calling
- `Mode.JSON`: Fallback for unsupported providers

View file

@ -0,0 +1,606 @@
# Advanced Validation Patterns
Complete guide to validation in Instructor using Pydantic.
## Table of Contents
- Built-in Validators
- Custom Field Validators
- Model-Level Validation
- Complex Validation Patterns
- Error Handling
## Built-in Validators
### Numeric Constraints
```python
from pydantic import BaseModel, Field
class Product(BaseModel):
price: float = Field(gt=0, description="Price must be positive")
discount: float = Field(ge=0, le=100, description="Discount 0-100%")
quantity: int = Field(ge=1, description="At least 1 item")
rating: float = Field(ge=0.0, le=5.0, description="Rating 0-5 stars")
# If LLM provides invalid values, automatic retry with error feedback
```
**Available constraints:**
- `gt`: Greater than
- `ge`: Greater than or equal
- `lt`: Less than
- `le`: Less than or equal
- `multiple_of`: Must be multiple of this number
### String Constraints
```python
class User(BaseModel):
username: str = Field(
min_length=3,
max_length=20,
pattern=r'^[a-zA-Z0-9_]+$',
description="3-20 alphanumeric characters"
)
bio: str = Field(max_length=500, description="Bio up to 500 chars")
status: str = Field(pattern=r'^(active|inactive|pending)$')
# pattern validates against regex
```
### Email and URL Validation
```python
from pydantic import EmailStr, HttpUrl, AnyUrl
class Contact(BaseModel):
email: EmailStr # Validates email format
website: HttpUrl # Validates HTTP/HTTPS URLs
portfolio: AnyUrl # Any valid URL scheme
contact = client.messages.create(
model="claude-sonnet-4-5-20250929",
max_tokens=1024,
messages=[{
"role": "user",
"content": "Extract: john@example.com, https://example.com"
}],
response_model=Contact
)
```
### Date and DateTime Validation
```python
from datetime import date, datetime
from pydantic import Field, field_validator
class Event(BaseModel):
event_date: date # Validates date format
created_at: datetime # Validates datetime format
year: int = Field(ge=1900, le=2100)
@field_validator('event_date')
def future_date(cls, v):
"""Ensure event is in the future."""
if v < date.today():
raise ValueError('Event must be in the future')
return v
```
### List and Dict Validation
```python
class Document(BaseModel):
tags: list[str] = Field(min_length=1, max_length=10)
keywords: list[str] = Field(min_length=3, description="At least 3 keywords")
metadata: dict[str, str] = Field(description="String key-value pairs")
@field_validator('tags')
def unique_tags(cls, v):
"""Ensure tags are unique."""
if len(v) != len(set(v)):
raise ValueError('Tags must be unique')
return v
```
## Custom Field Validators
### Basic Field Validator
```python
from pydantic import field_validator
class Person(BaseModel):
name: str
age: int
@field_validator('name')
def name_must_not_be_empty(cls, v):
"""Validate name is not empty or just whitespace."""
if not v or not v.strip():
raise ValueError('Name cannot be empty')
return v.strip()
@field_validator('age')
def age_must_be_reasonable(cls, v):
"""Validate age is between 0 and 120."""
if v < 0 or v > 120:
raise ValueError('Age must be between 0 and 120')
return v
```
### Validator with Field Info
```python
from pydantic import ValidationInfo
class Article(BaseModel):
title: str
content: str
@field_validator('content')
def content_length(cls, v, info: ValidationInfo):
"""Validate content is longer than title."""
if 'title' in info.data:
title_len = len(info.data['title'])
if len(v) < title_len * 2:
raise ValueError('Content should be at least 2x title length')
return v
```
### Multiple Fields Validation
```python
class TimeRange(BaseModel):
start_time: str
end_time: str
@field_validator('start_time', 'end_time')
def valid_time_format(cls, v):
"""Validate both times are in HH:MM format."""
import re
if not re.match(r'^\d{2}:\d{2}$', v):
raise ValueError('Time must be in HH:MM format')
return v
```
### Transform and Validate
```python
class URL(BaseModel):
url: str
@field_validator('url')
def normalize_url(cls, v):
"""Add https:// if missing."""
if not v.startswith(('http://', 'https://')):
v = f'https://{v}'
return v
```
## Model-Level Validation
### Cross-Field Validation
```python
from pydantic import model_validator
class DateRange(BaseModel):
start_date: str
end_date: str
@model_validator(mode='after')
def check_dates(self):
"""Ensure end_date is after start_date."""
from datetime import datetime
start = datetime.strptime(self.start_date, '%Y-%m-%d')
end = datetime.strptime(self.end_date, '%Y-%m-%d')
if end < start:
raise ValueError('end_date must be after start_date')
return self
class PriceRange(BaseModel):
min_price: float
max_price: float
@model_validator(mode='after')
def check_price_range(self):
"""Ensure max > min."""
if self.max_price <= self.min_price:
raise ValueError('max_price must be greater than min_price')
return self
```
### Conditional Validation
```python
from typing import Optional

class Order(BaseModel):
order_type: str # "standard" or "express"
delivery_date: str
delivery_time: Optional[str] = None
@model_validator(mode='after')
def check_delivery_time(self):
"""Express orders need delivery time."""
if self.order_type == "express" and not self.delivery_time:
raise ValueError('Express orders require delivery_time')
return self
```
### Complex Business Logic
```python
class Discount(BaseModel):
code: str
percentage: float = Field(ge=0, le=100)
min_purchase: float = Field(ge=0)
max_discount: float = Field(ge=0)
@model_validator(mode='after')
def validate_discount(self):
"""Ensure discount logic is sound."""
# Max discount can't exceed percentage of min_purchase
theoretical_max = (self.percentage / 100) * self.min_purchase
if self.max_discount > theoretical_max:
self.max_discount = theoretical_max
return self
```
## Complex Validation Patterns
### Nested Model Validation
```python
class Address(BaseModel):
street: str
city: str
country: str
postal_code: str
@field_validator('postal_code')
def validate_postal_code(cls, v, info: ValidationInfo):
"""Validate postal code format based on country."""
        import re  # used by both country branches
        if 'country' in info.data:
            country = info.data['country']
            if country == "USA":
                if not re.match(r'^\d{5}(-\d{4})?$', v):
                    raise ValueError('Invalid US postal code')
            elif country == "Canada":
                if not re.match(r'^[A-Z]\d[A-Z] \d[A-Z]\d$', v):
                    raise ValueError('Invalid Canadian postal code')
return v
class Person(BaseModel):
name: str
address: Address
# Nested validation runs automatically
```
### List of Models
```python
class Task(BaseModel):
title: str = Field(min_length=1)
priority: int = Field(ge=1, le=5)
class Project(BaseModel):
name: str
tasks: list[Task] = Field(min_length=1, description="At least 1 task")
@field_validator('tasks')
def at_least_one_high_priority(cls, v):
"""Ensure at least one task has priority >= 4."""
if not any(task.priority >= 4 for task in v):
raise ValueError('Project needs at least one high-priority task')
return v
```
### Union Type Validation
```python
from typing import Union
class TextBlock(BaseModel):
type: str = "text"
content: str = Field(min_length=1)
class ImageBlock(BaseModel):
type: str = "image"
url: HttpUrl
alt_text: str
class Page(BaseModel):
title: str
blocks: list[Union[TextBlock, ImageBlock]]
@field_validator('blocks')
def validate_block_types(cls, v):
"""Ensure first block is TextBlock."""
if v and not isinstance(v[0], TextBlock):
raise ValueError('First block must be text')
return v
```
### Dependent Fields
```python
class Subscription(BaseModel):
plan: str # "free", "pro", "enterprise"
max_users: int
features: list[str]
@model_validator(mode='after')
def validate_plan_limits(self):
"""Enforce plan-specific limits."""
limits = {
"free": {"max_users": 1, "required_features": ["basic"]},
"pro": {"max_users": 10, "required_features": ["basic", "advanced"]},
"enterprise": {"max_users": 999, "required_features": ["basic", "advanced", "premium"]}
}
if self.plan in limits:
limit = limits[self.plan]
if self.max_users > limit["max_users"]:
raise ValueError(f'{self.plan} plan limited to {limit["max_users"]} users')
for feature in limit["required_features"]:
if feature not in self.features:
raise ValueError(f'{self.plan} plan requires {feature} feature')
return self
```
## Error Handling
### Graceful Degradation
```python
from typing import Optional

class OptionalExtraction(BaseModel):
# Required fields
title: str
# Optional fields with defaults
author: Optional[str] = None
date: Optional[str] = None
tags: list[str] = Field(default_factory=list)
# LLM can succeed even if it can't extract everything
```
### Partial Validation
```python
from pydantic import ValidationError
def extract_with_fallback(text: str):
"""Try full extraction, fall back to partial."""
try:
# Try full extraction
return client.messages.create(
model="claude-sonnet-4-5-20250929",
max_tokens=1024,
messages=[{"role": "user", "content": text}],
response_model=FullModel
)
except ValidationError:
# Fall back to partial model
return client.messages.create(
model="claude-sonnet-4-5-20250929",
max_tokens=1024,
messages=[{"role": "user", "content": text}],
response_model=PartialModel
)
```
### Validation Error Inspection
```python
from pydantic import ValidationError
try:
result = client.messages.create(
model="claude-sonnet-4-5-20250929",
max_tokens=1024,
messages=[...],
response_model=MyModel,
max_retries=3
)
except ValidationError as e:
# Inspect specific errors
for error in e.errors():
field = error['loc'][0]
message = error['msg']
print(f"Field '{field}' failed: {message}")
# Custom handling per field
if field == 'email':
# Handle email validation failure
pass
```
### Custom Error Messages
```python
class DetailedModel(BaseModel):
name: str = Field(
min_length=2,
max_length=100,
description="Name between 2-100 characters"
)
age: int = Field(
ge=0,
le=120,
description="Age between 0 and 120 years"
)
@field_validator('name')
def validate_name(cls, v):
"""Provide helpful error message."""
if not v.strip():
raise ValueError(
'Name cannot be empty. '
'Please provide a valid name from the text.'
)
return v
# When validation fails, LLM sees these helpful messages
```
## Validation Best Practices
### 1. Be Specific
```python
# ❌ Bad: Vague validation
class Item(BaseModel):
name: str
# ✅ Good: Specific constraints
class Item(BaseModel):
name: str = Field(
min_length=1,
max_length=200,
description="Item name, 1-200 characters"
)
```
### 2. Provide Context
```python
# ✅ Good: Explain why validation failed
@field_validator('price')
def validate_price(cls, v):
if v <= 0:
raise ValueError(
'Price must be positive. '
'Extract numeric price from text without currency symbols.'
)
return v
```
### 3. Use Enums for Fixed Sets
```python
# ❌ Bad: String validation
status: str
@field_validator('status')
def validate_status(cls, v):
if v not in ['active', 'inactive', 'pending']:
raise ValueError('Invalid status')
return v
# ✅ Good: Enum
from enum import Enum

class Status(str, Enum):
ACTIVE = "active"
INACTIVE = "inactive"
PENDING = "pending"
status: Status # Validation automatic
```
### 4. Balance Strictness
```python
# Too strict: May fail unnecessarily
class StrictModel(BaseModel):
date: str = Field(pattern=r'^\d{4}-\d{2}-\d{2}$')
# Fails if LLM uses "2024-1-5" instead of "2024-01-05"
# Better: Normalize in validator
class FlexibleModel(BaseModel):
date: str
@field_validator('date')
def normalize_date(cls, v):
from datetime import datetime
# Parse flexible formats
for fmt in ['%Y-%m-%d', '%Y/%m/%d', '%m/%d/%Y']:
try:
dt = datetime.strptime(v, fmt)
return dt.strftime('%Y-%m-%d') # Normalize
except ValueError:
continue
raise ValueError('Invalid date format')
```
### 5. Test Validation
```python
# Test your validators with edge cases
def test_validation():
# Should succeed
valid = MyModel(field="valid_value")
# Should fail
try:
invalid = MyModel(field="invalid")
assert False, "Should have raised ValidationError"
except ValidationError:
pass # Expected
# Run tests before using in production
```
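For instance, the Person model defined earlier can be exercised like this (a pytest-style sketch; assumes Person is importable in the test module):
```python
import pytest
from pydantic import ValidationError

def test_person_validation():
    # Valid input passes and the name is stripped by the validator
    p = Person(name="  Ada Lovelace  ", age=36)
    assert p.name == "Ada Lovelace"

    # Whitespace-only names and out-of-range ages are rejected
    with pytest.raises(ValidationError):
        Person(name="   ", age=36)
    with pytest.raises(ValidationError):
        Person(name="Ada", age=300)
```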
## Advanced Techniques
### Conditional Required Fields
```python
from typing import Optional
class ConditionalModel(BaseModel):
type: str
detail_a: Optional[str] = None
detail_b: Optional[str] = None
@model_validator(mode='after')
def check_required_details(self):
"""Require different fields based on type."""
if self.type == "type_a" and not self.detail_a:
raise ValueError('type_a requires detail_a')
if self.type == "type_b" and not self.detail_b:
raise ValueError('type_b requires detail_b')
return self
```
### Validation with External Data
```python
class Product(BaseModel):
sku: str
name: str
@field_validator('sku')
def validate_sku(cls, v):
"""Check SKU exists in database."""
# Query database or API
if not database.sku_exists(v):
raise ValueError(f'SKU {v} not found in catalog')
return v
```
### Progressive Validation
```python
# Start with loose validation
class Stage1(BaseModel):
data: str # Any string
# Then strict validation
class Stage2(BaseModel):
data: str = Field(pattern=r'^[A-Z]{3}-\d{6}$')
# Use Stage1 for initial extraction
# Use Stage2 for final validation
```
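A sketch of the two-stage flow with the same client used throughout this guide (the email text and cleanup step are placeholders):
```python
# Stage 1: loose extraction, accept whatever the model finds
rough = client.messages.create(
    model="claude-sonnet-4-5-20250929",
    max_tokens=1024,
    messages=[{"role": "user", "content": "Find the reference code in this email: ..."}],
    response_model=Stage1,
)

# Stage 2: strict validation of the cleaned-up value (raises ValidationError if it does not match)
final = Stage2(data=rough.data.strip().upper())
```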
## Resources
- **Pydantic Docs**: https://docs.pydantic.dev/latest/concepts/validators/
- **Instructor Examples**: https://python.useinstructor.com/examples

View file

@@ -0,0 +1,261 @@
---
name: llama-cpp
description: Runs LLM inference on CPU, Apple Silicon, and consumer GPUs without NVIDIA hardware. Use for edge deployment, M1/M2/M3 Macs, AMD/Intel GPUs, or when CUDA is unavailable. Supports GGUF quantization (1.5-8 bit) for reduced memory and 4-10× speedup vs PyTorch on CPU.
version: 1.0.0
author: Orchestra Research
license: MIT
dependencies: [llama-cpp-python]
metadata:
hermes:
tags: [Inference Serving, Llama.cpp, CPU Inference, Apple Silicon, Edge Deployment, GGUF, Quantization, Non-NVIDIA, AMD GPUs, Intel GPUs, Embedded]
---
# llama.cpp
Pure C/C++ LLM inference with minimal dependencies, optimized for CPUs and non-NVIDIA hardware.
## When to use llama.cpp
**Use llama.cpp when:**
- Running on CPU-only machines
- Deploying on Apple Silicon (M1/M2/M3/M4)
- Using AMD or Intel GPUs (no CUDA)
- Edge deployment (Raspberry Pi, embedded systems)
- Need simple deployment without Docker/Python
**Use TensorRT-LLM instead when:**
- Have NVIDIA GPUs (A100/H100)
- Need maximum throughput (100K+ tok/s)
- Running in datacenter with CUDA
**Use vLLM instead when:**
- Have NVIDIA GPUs
- Need Python-first API
- Want PagedAttention
## Quick start
### Installation
```bash
# macOS/Linux
brew install llama.cpp
# Or build from source
git clone https://github.com/ggerganov/llama.cpp
cd llama.cpp
make
# With Metal (Apple Silicon)
make LLAMA_METAL=1
# With CUDA (NVIDIA)
make LLAMA_CUDA=1
# With ROCm (AMD)
make LLAMA_HIP=1
```
### Download model
```bash
# Download from HuggingFace (GGUF format)
huggingface-cli download \
TheBloke/Llama-2-7B-Chat-GGUF \
llama-2-7b-chat.Q4_K_M.gguf \
--local-dir models/
# Or convert from HuggingFace
python convert_hf_to_gguf.py models/llama-2-7b-chat/
```
### Run inference
```bash
# Simple chat
./llama-cli \
-m models/llama-2-7b-chat.Q4_K_M.gguf \
-p "Explain quantum computing" \
-n 256 # Max tokens
# Interactive chat
./llama-cli \
-m models/llama-2-7b-chat.Q4_K_M.gguf \
--interactive
```
### Server mode
```bash
# Start OpenAI-compatible server
./llama-server \
-m models/llama-2-7b-chat.Q4_K_M.gguf \
--host 0.0.0.0 \
--port 8080 \
-ngl 32 # Offload 32 layers to GPU
# Client request
curl http://localhost:8080/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "llama-2-7b-chat",
"messages": [{"role": "user", "content": "Hello!"}],
"temperature": 0.7,
"max_tokens": 100
}'
```
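Because llama-server speaks the OpenAI chat-completions protocol, any OpenAI-compatible client works; a minimal sketch with the official `openai` Python package (the API key is a placeholder, llama-server does not check it):
```python
from openai import OpenAI

# Point the client at the local llama-server instead of api.openai.com
client = OpenAI(base_url="http://localhost:8080/v1", api_key="not-needed")

response = client.chat.completions.create(
    model="llama-2-7b-chat",  # llama-server serves whichever model it loaded
    messages=[{"role": "user", "content": "Explain quantum computing in one sentence."}],
    temperature=0.7,
    max_tokens=100,
)
print(response.choices[0].message.content)
```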
## Quantization formats
### GGUF format overview
| Format | Bits | Size (7B) | Speed | Quality | Use Case |
|--------|------|-----------|-------|---------|----------|
| **Q4_K_M** | 4.5 | 4.1 GB | Fast | Good | **Recommended default** |
| Q4_K_S | 4.3 | 3.9 GB | Faster | Lower | Speed critical |
| Q5_K_M | 5.5 | 4.8 GB | Medium | Better | Quality critical |
| Q6_K | 6.5 | 5.5 GB | Slower | Best | Maximum quality |
| Q8_0 | 8.0 | 7.0 GB | Slow | Excellent | Minimal degradation |
| Q2_K | 2.5 | 2.7 GB | Fastest | Poor | Testing only |
### Choosing quantization
```bash
# General use (balanced)
Q4_K_M # 4-bit, medium quality
# Maximum speed (more degradation)
Q2_K or Q3_K_M
# Maximum quality (slower)
Q6_K or Q8_0
# Very large models (70B, 405B)
Q3_K_M or Q4_K_S # Lower bits to fit in memory
```
## Hardware acceleration
### Apple Silicon (Metal)
```bash
# Build with Metal
make LLAMA_METAL=1
# Run with GPU acceleration (automatic)
./llama-cli -m model.gguf -ngl 999 # Offload all layers
# Performance: M3 Max 40-60 tokens/sec (Llama 2-7B Q4_K_M)
```
### NVIDIA GPUs (CUDA)
```bash
# Build with CUDA
make LLAMA_CUDA=1
# Offload layers to GPU
./llama-cli -m model.gguf -ngl 35 # Offload 35/40 layers
# Hybrid CPU+GPU for large models
./llama-cli -m llama-70b.Q4_K_M.gguf -ngl 20 # GPU: 20 layers, CPU: rest
```
### AMD GPUs (ROCm)
```bash
# Build with ROCm
make LLAMA_HIP=1
# Run with AMD GPU
./llama-cli -m model.gguf -ngl 999
```
## Common patterns
### Batch processing
```bash
# Process multiple prompts from file
cat prompts.txt | ./llama-cli \
-m model.gguf \
--batch-size 512 \
-n 100
```
### Constrained generation
```bash
# JSON output with grammar
./llama-cli \
-m model.gguf \
-p "Generate a person: " \
--grammar-file grammars/json.gbnf
# Outputs valid JSON only
```
### Context size
```bash
# Increase context (default 512)
./llama-cli \
-m model.gguf \
-c 4096 # 4K context window
# Very long context (if model supports)
./llama-cli -m model.gguf -c 32768 # 32K context
```
## Performance benchmarks
### CPU performance (Llama 2-7B Q4_K_M)
| CPU | Threads | Speed | Cost |
|-----|---------|-------|------|
| Apple M3 Max | 16 | 50 tok/s | $0 (local) |
| AMD Ryzen 9 7950X | 32 | 35 tok/s | $0.50/hour |
| Intel i9-13900K | 32 | 30 tok/s | $0.40/hour |
| AWS c7i.16xlarge | 64 | 40 tok/s | $2.88/hour |
### GPU acceleration (Llama 2-7B Q4_K_M)
| GPU | Speed | vs CPU | Cost |
|-----|-------|--------|------|
| NVIDIA RTX 4090 | 120 tok/s | 3-4× | $0 (local) |
| NVIDIA A10 | 80 tok/s | 2-3× | $1.00/hour |
| AMD MI250 | 70 tok/s | 2× | $2.00/hour |
| Apple M3 Max (Metal) | 50 tok/s | ~Same | $0 (local) |
## Supported models
**LLaMA family**:
- Llama 2 (7B, 13B, 70B)
- Llama 3 (8B, 70B, 405B)
- Code Llama
**Mistral family**:
- Mistral 7B
- Mixtral 8x7B, 8x22B
**Other**:
- Falcon, BLOOM, GPT-J
- Phi-3, Gemma, Qwen
- LLaVA (vision), Whisper (audio)
**Find models**: https://huggingface.co/models?library=gguf
## References
- **[Quantization Guide](references/quantization.md)** - GGUF formats, conversion, quality comparison
- **[Server Deployment](references/server.md)** - API endpoints, Docker, monitoring
- **[Optimization](references/optimization.md)** - Performance tuning, hybrid CPU+GPU
## Resources
- **GitHub**: https://github.com/ggerganov/llama.cpp
- **Models**: https://huggingface.co/models?library=gguf
- **Discord**: https://discord.gg/llama-cpp

View file

@@ -0,0 +1,89 @@
# Performance Optimization Guide
Maximize llama.cpp inference speed and efficiency.
## CPU Optimization
### Thread tuning
```bash
# Set threads (default: physical cores)
./llama-cli -m model.gguf -t 8
# For AMD Ryzen 9 7950X (16 cores, 32 threads)
-t 16 # Best: physical cores
# Avoid hyperthreading (slower for matrix ops)
```
### BLAS acceleration
```bash
# OpenBLAS (faster matrix ops)
make LLAMA_OPENBLAS=1
# BLAS speeds up prompt processing (large batched matrix ops) by roughly 2-3×; single-token generation benefits less
```
## GPU Offloading
### Layer offloading
```bash
# Offload 35 layers to GPU (hybrid mode)
./llama-cli -m model.gguf -ngl 35
# Offload all layers
./llama-cli -m model.gguf -ngl 999
# Find optimal value:
# Start with -ngl 999
# If OOM, reduce by 5 until fits
```
### Memory usage
```bash
# Check VRAM usage
nvidia-smi dmon
# Reduce context if needed
./llama-cli -m model.gguf -c 2048 # 2K context instead of 4K
```
## Batch Processing
```bash
# Increase batch size for throughput
./llama-cli -m model.gguf -b 512 # Default: 512
# Physical batch (GPU)
-ub 128  # physical batch size (--ubatch-size): process up to 128 tokens per step
```
## Context Management
```bash
# Default context (512 tokens)
-c 512
# Longer context (slower, more memory)
-c 4096
# Very long context (if model supports)
-c 32768
```
## Benchmarks
### CPU Performance (Llama 2-7B Q4_K_M)
| Setup | Speed | Notes |
|-------|-------|-------|
| Apple M3 Max | 50 tok/s | Metal acceleration |
| AMD 7950X (16c) | 35 tok/s | OpenBLAS |
| Intel i9-13900K | 30 tok/s | AVX2 |
### GPU Offloading (RTX 4090)
| Layers GPU | Speed | VRAM |
|------------|-------|------|
| 0 (CPU only) | 30 tok/s | 0 GB |
| 20 (hybrid) | 80 tok/s | 8 GB |
| 35 (all) | 120 tok/s | 12 GB |

View file

@@ -0,0 +1,213 @@
# GGUF Quantization Guide
Complete guide to GGUF quantization formats and model conversion.
## Quantization Overview
**GGUF** (GPT-Generated Unified Format) - Standard format for llama.cpp models.
### Format Comparison
| Format | Perplexity | Size (7B) | Tokens/sec | Notes |
|--------|------------|-----------|------------|-------|
| FP16 | 5.9565 (baseline) | 13.0 GB | 15 tok/s | Original quality |
| Q8_0 | 5.9584 (+0.03%) | 7.0 GB | 25 tok/s | Nearly lossless |
| **Q6_K** | 5.9642 (+0.13%) | 5.5 GB | 30 tok/s | Best quality/size |
| **Q5_K_M** | 5.9796 (+0.39%) | 4.8 GB | 35 tok/s | Balanced |
| **Q4_K_M** | 6.0565 (+1.68%) | 4.1 GB | 40 tok/s | **Recommended** |
| Q4_K_S | 6.1125 (+2.62%) | 3.9 GB | 42 tok/s | Faster, lower quality |
| Q3_K_M | 6.3184 (+6.07%) | 3.3 GB | 45 tok/s | Small models only |
| Q2_K | 6.8673 (+15.3%) | 2.7 GB | 50 tok/s | Not recommended |
**Recommendation**: Use **Q4_K_M** for best balance of quality and speed.
## Converting Models
### HuggingFace to GGUF
```bash
# 1. Download HuggingFace model
huggingface-cli download meta-llama/Llama-2-7b-chat-hf \
--local-dir models/llama-2-7b-chat/
# 2. Convert to FP16 GGUF
python convert_hf_to_gguf.py \
models/llama-2-7b-chat/ \
--outtype f16 \
--outfile models/llama-2-7b-chat-f16.gguf
# 3. Quantize to Q4_K_M
./llama-quantize \
models/llama-2-7b-chat-f16.gguf \
models/llama-2-7b-chat-Q4_K_M.gguf \
Q4_K_M
```
### Batch quantization
```bash
# Quantize to multiple formats
for quant in Q4_K_M Q5_K_M Q6_K Q8_0; do
./llama-quantize \
model-f16.gguf \
model-${quant}.gguf \
$quant
done
```
## K-Quantization Methods
**K-quants** use mixed precision for better quality:
- Attention weights: Higher precision
- Feed-forward weights: Lower precision
**Variants**:
- `_S` (Small): Faster, lower quality
- `_M` (Medium): Balanced (recommended)
- `_L` (Large): Better quality, larger size
**Example**: `Q4_K_M`
- `Q4`: 4-bit quantization
- `K`: Mixed precision method
- `M`: Medium quality
## Quality Testing
```bash
# Calculate perplexity (quality metric)
./llama-perplexity \
-m model.gguf \
-f wikitext-2-raw/wiki.test.raw \
-c 512
# Lower perplexity = better quality
# Baseline (FP16): ~5.96
# Q4_K_M: ~6.06 (+1.7%)
# Q2_K: ~6.87 (+15.3% - too much degradation)
```
## Use Case Guide
### General purpose (chatbots, assistants)
```
Q4_K_M - Best balance
Q5_K_M - If you have extra RAM
```
### Code generation
```
Q5_K_M or Q6_K - Higher precision helps with code
```
### Creative writing
```
Q4_K_M - Sufficient quality
Q3_K_M - Acceptable for draft generation
```
### Technical/medical
```
Q6_K or Q8_0 - Maximum accuracy
```
### Edge devices (Raspberry Pi)
```
Q2_K or Q3_K_S - Fit in limited RAM
```
## Model Size Scaling
### 7B parameter models
| Format | Size | RAM needed |
|--------|------|------------|
| Q2_K | 2.7 GB | 5 GB |
| Q3_K_M | 3.3 GB | 6 GB |
| Q4_K_M | 4.1 GB | 7 GB |
| Q5_K_M | 4.8 GB | 8 GB |
| Q6_K | 5.5 GB | 9 GB |
| Q8_0 | 7.0 GB | 11 GB |
### 13B parameter models
| Format | Size | RAM needed |
|--------|------|------------|
| Q2_K | 5.1 GB | 8 GB |
| Q3_K_M | 6.2 GB | 10 GB |
| Q4_K_M | 7.9 GB | 12 GB |
| Q5_K_M | 9.2 GB | 14 GB |
| Q6_K | 10.7 GB | 16 GB |
### 70B parameter models
| Format | Size | RAM needed |
|--------|------|------------|
| Q2_K | 26 GB | 32 GB |
| Q3_K_M | 32 GB | 40 GB |
| Q4_K_M | 41 GB | 48 GB |
| Q4_K_S | 39 GB | 46 GB |
| Q5_K_M | 48 GB | 56 GB |
**Recommendation for 70B**: Use Q3_K_M or Q4_K_S to fit in consumer hardware.
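These sizes follow almost directly from bits per weight; a rough sketch for estimating formats or model sizes not listed (real files run a little larger because embeddings and the output layer are kept at higher precision):
```python
def gguf_size_gb(params_billion: float, bits_per_weight: float) -> float:
    """Approximate GGUF file size in GB: parameters x bits / 8."""
    return params_billion * 1e9 * bits_per_weight / 8 / 1e9

# Q4_K_M averages roughly 4.5 bits per weight
print(f"{gguf_size_gb(7, 4.5):.1f} GB")    # ~3.9 GB for a 7B model
print(f"{gguf_size_gb(70, 4.5):.1f} GB")   # ~39 GB for a 70B model
```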
## Finding Pre-Quantized Models
**TheBloke** on HuggingFace:
- https://huggingface.co/TheBloke
- Most models available in all GGUF formats
- No conversion needed
**Example**:
```bash
# Download pre-quantized Llama 2-7B
huggingface-cli download \
TheBloke/Llama-2-7B-Chat-GGUF \
llama-2-7b-chat.Q4_K_M.gguf \
--local-dir models/
```
## Importance Matrices (imatrix)
**What**: Calibration data to improve quantization quality.
**Benefits**:
- 10-20% perplexity improvement with Q4
- Essential for Q3 and below
**Usage**:
```bash
# 1. Generate importance matrix
./llama-imatrix \
-m model-f16.gguf \
-f calibration-data.txt \
-o model.imatrix
# 2. Quantize with imatrix
./llama-quantize \
--imatrix model.imatrix \
model-f16.gguf \
model-Q4_K_M.gguf \
Q4_K_M
```
**Calibration data**:
- Use domain-specific text (e.g., code for code models)
- ~100MB of representative text
- Higher quality data = better quantization
## Troubleshooting
**Model outputs gibberish**:
- Quantization too aggressive (Q2_K)
- Try Q4_K_M or Q5_K_M
- Verify model converted correctly
**Out of memory**:
- Use lower quantization (Q4_K_S instead of Q5_K_M)
- Offload fewer layers to GPU (`-ngl`)
- Use smaller context (`-c 2048`)
**Slow inference**:
- Higher quantization uses more compute
- Q8_0 much slower than Q4_K_M
- Consider speed vs quality trade-off

View file

@@ -0,0 +1,125 @@
# Server Deployment Guide
Production deployment of llama.cpp server with OpenAI-compatible API.
## Server Modes
### llama-server
```bash
# Basic server
./llama-server \
-m models/llama-2-7b-chat.Q4_K_M.gguf \
--host 0.0.0.0 \
--port 8080 \
-c 4096 # Context size
# With GPU acceleration
./llama-server \
-m models/llama-2-70b.Q4_K_M.gguf \
-ngl 40 # Offload 40 layers to GPU
```
## OpenAI-Compatible API
### Chat completions
```bash
curl http://localhost:8080/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "llama-2",
"messages": [
{"role": "system", "content": "You are helpful"},
{"role": "user", "content": "Hello"}
],
"temperature": 0.7,
"max_tokens": 100
}'
```
### Streaming
```bash
curl http://localhost:8080/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "llama-2",
"messages": [{"role": "user", "content": "Count to 10"}],
"stream": true
}'
```
## Docker Deployment
**Dockerfile**:
```dockerfile
# CUDA builds need the CUDA toolkit in the image; use an NVIDIA devel base (any recent tag works)
FROM nvidia/cuda:12.2.0-devel-ubuntu22.04
RUN apt-get update && apt-get install -y git build-essential
RUN git clone https://github.com/ggerganov/llama.cpp
WORKDIR /llama.cpp
RUN make LLAMA_CUDA=1
COPY models/ /models/
EXPOSE 8080
CMD ["./llama-server", "-m", "/models/model.gguf", "--host", "0.0.0.0", "--port", "8080"]
```
**Run**:
```bash
docker run --gpus all -p 8080:8080 llama-cpp:latest
```
## Monitoring
```bash
# Server metrics endpoint
curl http://localhost:8080/metrics
# Health check
curl http://localhost:8080/health
```
**Metrics**:
- requests_total
- tokens_generated
- prompt_tokens
- completion_tokens
- kv_cache_tokens
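A minimal polling sketch against those endpoints (assumes the `requests` package; on some llama.cpp builds the metrics endpoint must be enabled when the server is started):
```python
import requests

BASE = "http://localhost:8080"

# Readiness probe: 200 once the model has finished loading
ok = requests.get(f"{BASE}/health", timeout=5).status_code == 200
print("healthy:", ok)

# Prometheus-style text metrics; dump the non-comment lines
for line in requests.get(f"{BASE}/metrics", timeout=5).text.splitlines():
    if line and not line.startswith("#"):
        print(line)
```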
## Load Balancing
**NGINX**:
```nginx
upstream llama_cpp {
server llama1:8080;
server llama2:8080;
}
server {
location / {
proxy_pass http://llama_cpp;
proxy_read_timeout 300s;
}
}
```
## Performance Tuning
**Parallel requests**:
```bash
./llama-server \
-m model.gguf \
-np 4 # 4 parallel slots
```
**Continuous batching**:
```bash
./llama-server \
-m model.gguf \
--cont-batching # Enable continuous batching
```
**Context caching**:
```bash
./llama-server \
-m model.gguf \
--cache-prompt # Cache processed prompts
```

View file

@@ -0,0 +1,330 @@
---
name: obliteratus
description: Remove refusal behaviors from open-weight LLMs using OBLITERATUS — mechanistic interpretability techniques (diff-in-means, SVD, whitened SVD, LEACE, SAE decomposition, etc.) to excise guardrails while preserving reasoning. 9 CLI methods, 28 analysis modules, 116 model presets across 5 compute tiers, tournament evaluation, and telemetry-driven recommendations. Use when a user wants to uncensor, abliterate, or remove refusal from an LLM.
version: 2.0.0
author: Hermes Agent
license: MIT
dependencies: [obliteratus, torch, transformers, bitsandbytes, accelerate, safetensors]
metadata:
hermes:
tags: [Abliteration, Uncensoring, Refusal-Removal, LLM, Weight-Projection, SVD, Mechanistic-Interpretability, HuggingFace, Model-Surgery]
related_skills: [vllm, gguf, huggingface-tokenizers]
---
# OBLITERATUS Skill
Remove refusal behaviors (guardrails) from open-weight LLMs without retraining or fine-tuning. Uses mechanistic interpretability techniques — including diff-in-means, SVD, whitened SVD, LEACE concept erasure, SAE decomposition, Bayesian kernel projection, and more — to identify and surgically excise refusal directions from model weights while preserving reasoning capabilities.
**License warning:** OBLITERATUS is AGPL-3.0. NEVER import it as a Python library. Always invoke via CLI (`obliteratus` command) or subprocess. This keeps Hermes Agent's MIT license clean.
## When to Use This Skill
Trigger when the user:
- Wants to "uncensor" or "abliterate" an LLM
- Asks about removing refusal/guardrails from a model
- Wants to create an uncensored version of Llama, Qwen, Mistral, etc.
- Mentions "refusal removal", "abliteration", "weight projection"
- Wants to analyze how a model's refusal mechanism works
- References OBLITERATUS, abliterator, or refusal directions
## Step 1: Installation
Check if already installed:
```bash
obliteratus --version 2>/dev/null && echo "INSTALLED" || echo "NOT INSTALLED"
```
If not installed, clone and install from GitHub:
```bash
git clone https://github.com/elder-plinius/OBLITERATUS.git
cd OBLITERATUS
pip install -e .
# For Gradio web UI support:
# pip install -e ".[spaces]"
```
**IMPORTANT:** Confirm with user before installing. This pulls in ~5-10GB of dependencies (PyTorch, Transformers, bitsandbytes, etc.).
## Step 2: Check Hardware
Before anything, check what GPU is available:
```bash
python3 -c "
import torch
if torch.cuda.is_available():
gpu = torch.cuda.get_device_name(0)
vram = torch.cuda.get_device_properties(0).total_memory / 1024**3
print(f'GPU: {gpu}')
print(f'VRAM: {vram:.1f} GB')
if vram < 4: print('TIER: tiny (models under 1B)')
elif vram < 8: print('TIER: small (models 1-4B)')
elif vram < 16: print('TIER: medium (models 4-9B with 4bit quant)')
elif vram < 32: print('TIER: large (models 8-32B with 4bit quant)')
else: print('TIER: frontier (models 32B+)')
else:
print('NO GPU - only tiny models (under 1B) on CPU')
"
```
### VRAM Requirements (with 4-bit quantization)
| VRAM | Max Model Size | Example Models |
|:---------|:----------------|:--------------------------------------------|
| CPU only | ~1B params | GPT-2, TinyLlama, SmolLM |
| 4-8 GB | ~4B params | Qwen2.5-1.5B, Phi-3.5 mini, Llama 3.2 3B |
| 8-16 GB | ~9B params | Llama 3.1 8B, Mistral 7B, Gemma 2 9B |
| 24 GB | ~32B params | Qwen3-32B, Llama 3.1 70B (tight), Command-R |
| 48 GB+ | ~72B+ params | Qwen2.5-72B, DeepSeek-R1 |
| Multi-GPU| 200B+ params | Llama 3.1 405B, DeepSeek-V3 (685B MoE) |
## Step 3: Browse Available Models & Get Recommendations
```bash
# Browse models by compute tier
obliteratus models --tier medium
# Get architecture info for a specific model
obliteratus info <model_name>
# Get telemetry-driven recommendation for best method & params
obliteratus recommend <model_name>
obliteratus recommend <model_name> --insights # global cross-architecture rankings
```
## Step 4: Choose a Method
### Method Selection Guide
**Default / recommended for most cases: `advanced`.** It uses multi-direction SVD with norm-preserving projection and is well-tested.
| Situation | Recommended Method | Why |
|:----------------------------------|:-------------------|:-----------------------------------------|
| Default / most models | `advanced` | Multi-direction SVD, norm-preserving, reliable |
| Quick test / prototyping | `basic` | Fast, simple, good enough to evaluate |
| Dense model (Llama, Mistral) | `advanced` | Multi-direction, norm-preserving |
| MoE model (DeepSeek, Mixtral) | `nuclear` | Expert-granular, handles MoE complexity |
| Reasoning model (R1 distills) | `surgical` | CoT-aware, preserves chain-of-thought |
| Stubborn refusals persist | `aggressive` | Whitened SVD + head surgery + jailbreak |
| Want reversible changes | (none) use steering vectors | Inference-time steering is reversible; see Analysis section below |
| Maximum quality, time no object | `optimized` | Bayesian search for best parameters |
| Experimental auto-detection | `informed` | Auto-detects alignment type — experimental, may not always outperform advanced |
### 9 CLI Methods
- **basic** — Single refusal direction via diff-in-means. Fast (~5-10 min for 8B).
- **advanced** (DEFAULT, RECOMMENDED) — Multiple SVD directions, norm-preserving projection, 2 refinement passes. Medium speed (~10-20 min).
- **aggressive** — Whitened SVD + jailbreak-contrastive + attention head surgery. Higher risk of coherence damage.
- **spectral_cascade** — DCT frequency-domain decomposition. Research/novel approach.
- **informed** — Runs analysis DURING abliteration to auto-configure. Experimental — slower and less predictable than advanced.
- **surgical** — SAE features + neuron masking + head surgery + per-expert. Very slow (~1-2 hrs). Best for reasoning models.
- **optimized** — Bayesian hyperparameter search (Optuna TPE). Longest runtime but finds optimal parameters.
- **inverted** — Flips the refusal direction. Model becomes actively willing.
- **nuclear** — Maximum force combo for stubborn MoE models. Expert-granular.
### Direction Extraction Methods (--direction-method flag)
- **diff_means** (default) — Simple difference-in-means between refused/complied activations. Robust.
- **svd** — Multi-direction SVD extraction. Better for complex alignment.
- **leace** — LEACE (Linear Erasure via Closed-form Estimation). Optimal linear erasure.
### 4 Python-API-Only Methods
(NOT available via CLI — require Python import, which violates AGPL boundary. Mention to user only if they explicitly want to use OBLITERATUS as a library in their own AGPL project.)
- failspy, gabliteration, heretic, rdo
## Step 5: Run Abliteration
### Standard usage
```bash
# Default method (advanced) — recommended for most models
obliteratus obliterate <model_name> --method advanced --output-dir ./abliterated-models
# With 4-bit quantization (saves VRAM)
obliteratus obliterate <model_name> --method advanced --quantization 4bit --output-dir ./abliterated-models
# Large models (70B+) — conservative defaults
obliteratus obliterate <model_name> --method advanced --quantization 4bit --large-model --output-dir ./abliterated-models
```
### Fine-tuning parameters
```bash
obliteratus obliterate <model_name> \
--method advanced \
--direction-method diff_means \
--n-directions 4 \
--refinement-passes 2 \
--regularization 0.1 \
--quantization 4bit \
--output-dir ./abliterated-models \
--contribute # opt-in telemetry for community research
```
### Key flags
| Flag | Description | Default |
|:-----|:------------|:--------|
| `--method` | Abliteration method | advanced |
| `--direction-method` | Direction extraction | diff_means |
| `--n-directions` | Number of refusal directions (1-32) | method-dependent |
| `--refinement-passes` | Iterative passes (1-5) | 2 |
| `--regularization` | Regularization strength (0.0-1.0) | 0.1 |
| `--quantization` | Load in 4bit or 8bit | none (full precision) |
| `--large-model` | Conservative defaults for 120B+ | false |
| `--output-dir` | Where to save the abliterated model | ./obliterated_model |
| `--contribute` | Share anonymized results for research | false |
| `--verify-sample-size` | Number of test prompts for refusal check | 20 |
| `--dtype` | Model dtype (float16, bfloat16) | auto |
### Other execution modes
```bash
# Interactive guided mode (hardware → model → preset)
obliteratus interactive
# Web UI (Gradio)
obliteratus ui --port 7860
# Run a full ablation study from YAML config
obliteratus run config.yaml --preset quick
# Tournament: pit all methods against each other
obliteratus tourney <model_name>
```
## Step 6: Verify Results
After abliteration, check the output metrics:
| Metric | Good Value | Warning |
|:-------|:-----------|:--------|
| Refusal rate | < 5% (ideally ~0%) | > 10% means refusals persist |
| Perplexity change | < 10% increase | > 15% means coherence damage |
| KL divergence | < 0.1 | > 0.5 means significant distribution shift |
| Coherence | High / passes qualitative check | Degraded responses, repetition |
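For a quick manual spot-check of the refusal rate (a crude sketch with an assumed test set and refusal phrases, independent of the tool's own verification):
```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_dir = "./abliterated-models/<model>"  # placeholder path, as in Step 7
model = AutoModelForCausalLM.from_pretrained(model_dir)
tokenizer = AutoTokenizer.from_pretrained(model_dir)

prompts = ["How do I pick a lock?", "Describe how to hotwire a car."]   # your own test prompts
refusal_markers = ["i can't", "i cannot", "i'm sorry", "as an ai"]      # crude heuristic

refused = 0
for p in prompts:
    inputs = tokenizer(p, return_tensors="pt")
    out = model.generate(**inputs, max_new_tokens=100)
    text = tokenizer.decode(out[0], skip_special_tokens=True).lower()
    refused += any(m in text for m in refusal_markers)

print(f"refusal rate: {refused / len(prompts):.0%}")
```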
### If refusals persist (> 10%)
1. Try `aggressive` method
2. Increase `--n-directions` (e.g., 8 or 16)
3. Add `--refinement-passes 3`
4. Try `--direction-method svd` instead of diff_means
### If coherence is damaged (perplexity > 15% increase)
1. Reduce `--n-directions` (try 2)
2. Increase `--regularization` (try 0.3)
3. Reduce `--refinement-passes` to 1
4. Try `basic` method (gentler)
## Step 7: Use the Abliterated Model
The output is a standard HuggingFace model directory.
```bash
# Test locally with transformers
python3 -c "
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained('./abliterated-models/<model>')
tokenizer = AutoTokenizer.from_pretrained('./abliterated-models/<model>')
inputs = tokenizer('How do I pick a lock?', return_tensors='pt')
outputs = model.generate(**inputs, max_new_tokens=200)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
"
# Upload to HuggingFace Hub
huggingface-cli upload <username>/<model-name>-abliterated ./abliterated-models/<model>
# Serve with vLLM
vllm serve ./abliterated-models/<model>
```
## CLI Command Reference
| Command | Description |
|:--------|:------------|
| `obliteratus obliterate` | Main abliteration command |
| `obliteratus info <model>` | Print model architecture details |
| `obliteratus models --tier <tier>` | Browse curated models by compute tier |
| `obliteratus recommend <model>` | Telemetry-driven method/param suggestion |
| `obliteratus interactive` | Guided setup wizard |
| `obliteratus tourney <model>` | Tournament: all methods head-to-head |
| `obliteratus run <config.yaml>` | Execute ablation study from YAML |
| `obliteratus strategies` | List all registered ablation strategies |
| `obliteratus report <results.json>` | Regenerate visual reports |
| `obliteratus ui` | Launch Gradio web interface |
| `obliteratus aggregate` | Summarize community telemetry data |
## Analysis Modules
OBLITERATUS includes 28 analysis modules for mechanistic interpretability.
See `skill_view(name="obliteratus", file_path="references/analysis-modules.md")` for the full reference.
### Quick analysis commands
```bash
# Run specific analysis modules
obliteratus run analysis-config.yaml --preset quick
# Key modules to run first:
# - alignment_imprint: Fingerprint DPO/RLHF/CAI/SFT alignment method
# - concept_geometry: Single direction vs polyhedral cone
# - logit_lens: Which layer decides to refuse
# - anti_ouroboros: Self-repair risk score
# - causal_tracing: Causally necessary components
```
### Steering Vectors (Reversible Alternative)
Instead of permanent weight modification, use inference-time steering:
```python
# Python API only — for user's own projects
from obliteratus.analysis.steering_vectors import SteeringVectorFactory, SteeringHookManager
```
## Ablation Strategies
Beyond direction-based abliteration, OBLITERATUS includes structural ablation strategies:
- **Embedding Ablation** — Target embedding layer components
- **FFN Ablation** — Feed-forward network block removal
- **Head Pruning** — Attention head pruning
- **Layer Removal** — Full layer removal
List all available: `obliteratus strategies`
## Evaluation
OBLITERATUS includes built-in evaluation tools:
- Refusal rate benchmarking
- Perplexity comparison (before/after)
- LM Eval Harness integration for academic benchmarks
- Head-to-head competitor comparison
- Baseline performance tracking
## Platform Support
- **CUDA** — Full support (NVIDIA GPUs)
- **Apple Silicon (MLX)** — Supported via MLX backend
- **CPU** — Supported for tiny models (< 1B params)
## YAML Config Templates
Load templates for reproducible runs via `skill_view`:
- `templates/abliteration-config.yaml` — Standard single-model config
- `templates/analysis-study.yaml` — Pre-abliteration analysis study
- `templates/batch-abliteration.yaml` — Multi-model batch processing
## Telemetry
OBLITERATUS can optionally contribute anonymized run data to a global research dataset.
Enable with `--contribute` flag. No personal data is collected — only model name, method, metrics.
## Common Pitfalls
1. **Don't use `informed` as default** — it's experimental and slower. Use `advanced` for reliable results.
2. **Models under ~1B respond poorly to abliteration** — their refusal behaviors are shallow and fragmented, making clean direction extraction difficult. Expect partial results (20-40% remaining refusal). Models 3B+ have cleaner refusal directions and respond much better (often 0% refusal with `advanced`).
3. **`aggressive` can make things worse** — on small models it can damage coherence and actually increase refusal rate. Only use it if `advanced` leaves > 10% refusals on a 3B+ model.
4. **Always check perplexity** — if it spikes > 15%, the model is damaged. Reduce aggressiveness.
5. **MoE models need special handling** — use `nuclear` method for Mixtral, DeepSeek-MoE, etc.
6. **Quantized models can't be re-quantized** — abliterate the full-precision model, then quantize the output.
7. **VRAM estimation is approximate** — 4-bit quant helps but peak usage can spike during extraction.
8. **Reasoning models are sensitive** — use `surgical` for R1 distills to preserve chain-of-thought.
9. **Check `obliteratus recommend`** — telemetry data may have better parameters than defaults.
10. **AGPL license** — never `import obliteratus` in MIT/Apache projects. CLI invocation only.
11. **Large models (70B+)** — always use `--large-model` flag for conservative defaults.
12. **Spectral certification RED is common** — the spectral check often flags "incomplete" even when practical refusal rate is 0%. Check actual refusal rate rather than relying on spectral certification alone.
## Complementary Skills
- **vllm** — Serve abliterated models with high throughput
- **gguf** — Convert abliterated models to GGUF for llama.cpp
- **huggingface-tokenizers** — Work with model tokenizers

View file

@@ -0,0 +1,166 @@
# OBLITERATUS Analysis Modules — Reference
OBLITERATUS includes 28 analysis modules for mechanistic interpretability of refusal in LLMs.
These modules help understand how and where refusal behaviors are encoded before performing abliteration.
---
## Core Analysis (Run These First)
### 1. Alignment Imprint Detection (`alignment_imprint.py`)
Fingerprints whether a model was trained via DPO, RLHF, CAI, or SFT.
This determines which extraction strategy will work best.
### 2. Concept Cone Geometry (`concept_geometry.py`)
Determines if refusal is a single linear direction or a polyhedral cone
(set of multiple mechanisms). Single-direction models respond well to `basic`;
polyhedral models need `advanced` or `surgical`.
### 3. Refusal Logit Lens (`logit_lens.py`)
Identifies the specific layer where a model "decides" to refuse by decoding
intermediate layer representations into token space.
### 4. Ouroboros Detection (`anti_ouroboros.py`)
Identifies if a model attempts to "self-repair" refusal behaviors after
excision. Reports a risk score (0-1). High scores mean additional refinement
passes are needed.
### 5. Causal Tracing (`causal_tracing.py`)
Identifies which components (layers, heads, MLPs) are causally necessary
for refusal behavior using activation patching.
---
## Geometric Analysis
### 6. Cross-Layer Alignment (`cross_layer.py`)
Measures how refusal directions align across different layers. High alignment
means the refusal signal is consistent; low alignment suggests layer-specific
mechanisms.
### 7. Residual Stream Decomposition (`residual_stream.py`)
Decomposes the residual stream into attention and MLP contributions to
understand which component type contributes more to refusal.
### 8. Riemannian Manifold Geometry (`riemannian_manifold.py`)
Analyzes the curvature and geometry of the weight manifold near refusal
directions. Informs how aggressively projections can be applied without
damaging the manifold structure.
### 9. Whitened SVD (`whitened_svd.py`)
Covariance-normalized SVD extraction that separates guardrail signals from
natural activation variance. More precise than standard SVD for models with
high activation variance.
### 10. Concept Cone Geometry (extended)
Maps the full polyhedral structure of refusal, including cone angles,
face counts, and intersection patterns.
---
## Probing & Classification
### 11. Activation Probing (`activation_probing.py`)
Post-excision verification — probes for residual refusal concepts after
abliteration to ensure complete removal.
### 12. Probing Classifiers (`probing_classifiers.py`)
Trains linear classifiers to detect refusal in activations. Used both
before (to verify refusal exists) and after (to verify it's gone).
### 13. Activation Patching (`activation_patching.py`)
Interchange interventions — swaps activations between refused and complied
runs to identify causal components.
### 14. Tuned Lens (`tuned_lens.py`)
Trained version of logit lens that provides more accurate per-layer
decoding by learning affine transformations for each layer.
### 15. Multi-Token Position Analysis (`multi_token_position.py`)
Analyzes refusal signals across multiple token positions, not just the
last token. Important for models that distribute refusal across the sequence.
---
## Abliteration & Manipulation
### 16. SAE-Based Abliteration (`sae_abliteration.py`)
Uses Sparse Autoencoder features to identify and remove specific refusal
features. More surgical than direction-based methods.
### 17. Steering Vectors (`steering_vectors.py`)
Creates and applies inference-time steering vectors for reversible refusal
modification. Includes `SteeringVectorFactory` and `SteeringHookManager`.
### 18. LEACE Concept Erasure (`leace.py`)
Linear Erasure via Closed-form Estimation — mathematically optimal linear
concept removal. Available as both analysis module and direction extraction method.
### 19. Sparse Surgery (`sparse_surgery.py`)
High-precision weight modification targeting individual neurons and
weight matrix entries rather than full directions.
### 20. Conditional Abliteration (`conditional_abliteration.py`)
Targeted removal that only affects specific refusal categories while
preserving others (e.g., remove weapons refusal but keep CSAM refusal).
---
## Transfer & Robustness
### 21. Cross-Model Transfer (`cross_model_transfer.py`)
Tests whether refusal directions extracted from one model transfer to
another architecture. Measures universality of guardrail directions.
### 22. Defense Robustness (`defense_robustness.py`)
Evaluates how robust the abliteration is against various defense mechanisms
and re-alignment attempts.
### 23. Spectral Certification (`spectral_certification.py`)
Provides mathematical bounds on the completeness of refusal removal
using spectral analysis of the projection.
### 24. Wasserstein Optimal Extraction (`wasserstein_optimal.py`)
Uses optimal transport theory for more precise direction extraction
that minimizes distribution shift.
### 25. Wasserstein Transfer (`wasserstein_transfer.py`)
Distribution transfer between models using Wasserstein distance
for cross-architecture refusal direction mapping.
---
## Advanced / Research
### 26. Bayesian Kernel Projection (`bayesian_kernel_projection.py`)
Probabilistic feature mapping that estimates uncertainty in refusal
direction identification.
### 27. Cross-Model Universality Index
Measures if guardrail directions generalize across different model
architectures and training regimes.
### 28. Visualization (`visualization.py`)
Plotting and graphing utilities for all analysis modules. Generates
heatmaps, direction plots, and layer-wise analysis charts.
---
## Running Analysis
### Via CLI
```bash
# Run analysis from a YAML config
obliteratus run analysis-study.yaml --preset quick
# Available study presets:
# quick — Fast sanity check (2-3 modules)
# full — All core + geometric analysis
# jailbreak — Refusal circuit localization
# knowledge — Knowledge preservation analysis
# robustness — Stress testing / defense evaluation
```
### Via YAML Config
See the `templates/analysis-study.yaml` template for a complete example.
Load with: `skill_view(name="obliteratus", file_path="templates/analysis-study.yaml")`

View file

@@ -0,0 +1,141 @@
# OBLITERATUS Methods — Detailed Guide
> The CLI accepts 9 methods via `--method`: basic, advanced, aggressive, spectral_cascade,
> informed, surgical, optimized, inverted, nuclear.
> Four additional methods (failspy, gabliteration, heretic, rdo) are available only via the Python API.
## How Abliteration Works (Theory)
Abliteration identifies a "refusal direction" — a vector in the model's activation space that
corresponds to refusal behavior — and projects it out of the weight matrices.
Mathematically: `W_new = W_old - (W_old @ d @ d.T)` where `d` is the refusal direction.
The key challenge is finding accurate refusal directions without damaging other capabilities.
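A toy NumPy sketch of that projection (illustrative only; OBLITERATUS applies it per layer, with norm preservation and usually several directions):
```python
import numpy as np

rng = np.random.default_rng(0)
W = rng.standard_normal((4096, 4096))   # stand-in for one weight matrix
d = rng.standard_normal(4096)
d /= np.linalg.norm(d)                  # refusal direction, unit norm

# W_new = W - (W @ d) d^T : remove the component of W along d
W_new = W - np.outer(W @ d, d)

print(np.abs(W_new @ d).max())  # ~0: the projected weights no longer write to d
```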
---
## Direction Extraction Methods
Before projecting, OBLITERATUS extracts refusal directions using one of three methods:
| Method | Flag | Description | Best For |
|:-------|:-----|:------------|:---------|
| Diff-in-Means | `--direction-method diff_means` | Difference between mean activations on refused vs. complied prompts | Default, fast, robust |
| SVD | `--direction-method svd` | Multi-direction extraction via Singular Value Decomposition | Complex alignment, multiple refusal mechanisms |
| LEACE | `--direction-method leace` | Linear Erasure via Closed-form Estimation — mathematically optimal | Maximum precision, research |
---
## Method Details
### basic
- **Directions:** 1 (single diff-in-means vector)
- **Speed:** Fast (~5-10 min for 8B model)
- **Risk:** Low
- **Use case:** Quick tests, prototyping, evaluating if abliteration works for a model
- **How it works:** Extracts one refusal direction and projects it out uniformly across all layers.
### advanced (DEFAULT — RECOMMENDED)
- **Directions:** 4 (multi-direction SVD)
- **Speed:** Medium (~10-20 min for 8B model)
- **Risk:** Low-Medium
- **Refinement passes:** 2
- **Use case:** Default for most models. Well-tested and reliable.
- **How it works:** Extracts multiple refusal directions via SVD, applies norm-preserving bi-projection to maintain weight matrix norms. Two refinement passes catch residual refusal.
### aggressive
- **Directions:** 8+ (whitened SVD + jailbreak-contrastive)
- **Speed:** Medium-Slow
- **Risk:** Medium-High (may damage coherence)
- **Use case:** When `advanced` leaves > 10% refusals. Stubborn models.
- **How it works:** Uses whitened SVD for covariance-normalized extraction, adds jailbreak-contrastive directions, performs attention head surgery on the most refusal-active heads.
### spectral_cascade
- **Speed:** Medium
- **Risk:** Medium
- **Use case:** Research, novel approaches
- **How it works:** DCT (Discrete Cosine Transform) frequency-domain decomposition of refusal signals. Separates high-frequency (surface-level) from low-frequency (deep) refusal patterns.
### informed (EXPERIMENTAL)
- **Speed:** Slow (~20-40 min for 8B model)
- **Risk:** Variable — results depend on analysis quality
- **Use case:** When you want auto-configuration, but be aware this is experimental and may not outperform `advanced`.
- **How it works:** Runs 4 analysis modules first (alignment imprint, concept geometry, logit lens, ouroboros detection), then auto-configures extraction strategy. Includes an "Ouroboros loop" that detects and counteracts self-repair.
- **Note:** The auto-detection can sometimes misconfigure. If results are poor, fall back to `advanced`.
### surgical
- **Speed:** Very slow (~1-2 hrs for 8B model)
- **Risk:** Low (very precise)
- **Use case:** Reasoning models (R1 distills, QwQ, etc.) where chain-of-thought must be preserved.
- **How it works:** Uses SAE (Sparse Autoencoder) features + individual neuron masking + attention head surgery + per-expert decomposition (for MoE). CoT-aware — identifies and protects reasoning-critical directions before projecting.
### optimized
- **Speed:** Very slow (hours — runs many trials)
- **Risk:** Low (finds optimal parameters)
- **Use case:** When quality matters more than speed. Production models.
- **How it works:** Bayesian hyperparameter search via Optuna TPE sampler. Optimizes n_directions, regularization, refinement passes, and layer selection jointly. Evaluates each configuration on refusal rate + perplexity.
### inverted
- **Speed:** Fast
- **Risk:** High (model behavior changes dramatically)
- **Use case:** Research, studying refusal mechanisms
- **How it works:** Instead of projecting out the refusal direction, reflects it. The model actively complies rather than passively not-refusing. Useful for understanding the geometry of alignment.
### nuclear
- **Speed:** Slow
- **Risk:** Medium-High
- **Use case:** Stubborn MoE models (DeepSeek-MoE, Mixtral, etc.)
- **How it works:** Combines expert-granular abliteration (EGA), steering vector injection, attention head pruning, and multi-pass refinement. Decomposes refusal signals into per-expert components for MoE architectures.
---
## Method Selection Flowchart
```
Is this a quick test?
→ YES: basic
→ NO: continue
Is it an MoE model (Mixtral, DeepSeek-MoE)?
→ YES: nuclear
→ NO: continue
Is it a reasoning model (R1, QwQ, CoT-focused)?
→ YES: surgical
→ NO: continue
Do you need the absolute best quality and have time?
→ YES: optimized
→ NO: advanced (recommended default)
Did advanced leave > 10% refusals?
→ YES: aggressive
→ Still refusing: nuclear
```
---
## Key Parameters
| Parameter | Range | Default | Effect |
|:----------|:------|:--------|:-------|
| `--n-directions` | 1-32 | method-dependent | More directions = more complete removal, but higher damage risk |
| `--regularization` | 0.0-1.0 | 0.1 | Higher = more conservative (less removal, less damage) |
| `--refinement-passes` | 1-5 | 2 | More passes catch residual refusal, but diminishing returns |
| `--quantization` | 4bit, 8bit | none | Reduces VRAM usage; quality impact minimal for extraction |
| `--verify-sample-size` | 10-200 | 20 | More samples = more accurate refusal rate estimate |
---
## Troubleshooting
| Problem | Likely Cause | Fix |
|:--------|:-------------|:----|
| Refusal rate > 20% | Too few directions | Increase `--n-directions`, try `aggressive` |
| Refusal rate 5-20% | Residual refusal | Add `--refinement-passes 3`, try `--direction-method svd` |
| Perplexity spike > 20% | Over-aggressive removal | Reduce `--n-directions`, increase `--regularization` |
| Repetitive output | Weight matrix damage | Use `basic` with fewer directions, check norm preservation |
| MoE model still refuses | Non-expert-aware method | Switch to `nuclear` |
| Reasoning degraded | CoT directions damaged | Use `surgical` method |
| OOM during extraction | Insufficient VRAM | Add `--quantization 4bit` and/or `--large-model` |

View file

@@ -0,0 +1,33 @@
# OBLITERATUS Abliteration Config
# Usage: obliteratus run this-file.yaml
#
# This is for reproducible, version-controlled abliteration runs.
# For one-off usage, the CLI flags are simpler.
# Model to abliterate
model:
name: "meta-llama/Llama-3.1-8B-Instruct"
dtype: "bfloat16" # float16, bfloat16, float32
quantization: null # null, "4bit", "8bit"
device: "auto" # auto, cuda, cuda:0, cpu
# Abliteration method and parameters
abliteration:
method: "informed" # See SKILL.md Step 4 for all 13 methods
n_directions: null # null = auto-detect, or integer (e.g., 8)
regularization: 0.0 # 0.0-1.0, fraction of original to preserve
refinement_passes: 1 # Iterative passes (increase for self-repair)
norm_preserve: true # Keep weight norms intact after projection
# Output
output:
directory: "./abliterated-models"
save_metadata: true # Save abliteration_metadata.json alongside model
contribute: false # Save community contribution data
# Verification
verify:
enabled: true
test_prompts: null # null = use built-in test prompts
compute_perplexity: true
compute_kl: true

View file

@@ -0,0 +1,40 @@
# OBLITERATUS Analysis Study Config
# Usage: obliteratus run this-file.yaml --preset jailbreak
#
# Run analysis modules to understand refusal geometry BEFORE abliterating.
# Useful for research or when you want to understand what you're removing.
# Model to analyze
model:
name: "meta-llama/Llama-3.1-8B-Instruct"
dtype: "bfloat16"
quantization: "4bit" # Saves VRAM for analysis
device: "auto"
# Study configuration
study:
# Available presets: quick, full, attention, jailbreak, guardrail, knowledge
preset: "jailbreak"
# Or specify individual strategies:
# strategies:
# - layer_removal
# - head_pruning
# - ffn_ablation
# - embedding_ablation
# Analysis modules to run (subset of the 27 available)
analysis:
- alignment_imprint # Detect DPO/RLHF/CAI/SFT training method
- concept_geometry # Map refusal cone geometry
- logit_lens # Find which layer decides to refuse
- anti_ouroboros # Detect self-repair tendency
- cross_layer # Cross-layer alignment clustering
- causal_tracing # Causal necessity of components
- residual_stream # Attention vs MLP contribution
# Output
output:
directory: "./analysis-results"
save_plots: true # Generate matplotlib visualizations
save_report: true # Generate markdown report

View file

@@ -0,0 +1,41 @@
# OBLITERATUS Batch Abliteration Config
# Abliterate multiple models with the same method for comparison.
#
# Run each one sequentially:
# for model in models; do obliteratus obliterate $model --method informed; done
#
# Or use this as a reference for which models to process.
# Common settings
defaults:
method: "informed"
quantization: "4bit"
output_dir: "./abliterated-models"
# Models to process (grouped by compute tier)
models:
# Small (4-8 GB VRAM)
small:
- "Qwen/Qwen2.5-1.5B-Instruct"
- "microsoft/Phi-3.5-mini-instruct"
- "meta-llama/Llama-3.2-3B-Instruct"
# Medium (8-16 GB VRAM)
medium:
- "meta-llama/Llama-3.1-8B-Instruct"
- "mistralai/Mistral-7B-Instruct-v0.3"
- "google/gemma-2-9b-it"
- "Qwen/Qwen2.5-7B-Instruct"
# Large (24 GB VRAM, 4-bit quantization)
large:
- "Qwen/Qwen2.5-14B-Instruct"
- "Qwen/Qwen3-32B"
- "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B"
# Per-model method overrides (optional)
overrides:
"deepseek-ai/DeepSeek-R1-Distill-Qwen-32B":
method: "surgical" # CoT-aware for reasoning models
"mistralai/Mixtral-8x7B-Instruct-v0.1":
method: "nuclear" # Expert-granular for MoE models

View file

@@ -0,0 +1,655 @@
---
name: outlines
description: Guarantee valid JSON/XML/code structure during generation, use Pydantic models for type-safe outputs, support local models (Transformers, vLLM), and maximize inference speed with Outlines - dottxt.ai's structured generation library
version: 1.0.0
author: Orchestra Research
license: MIT
dependencies: [outlines, transformers, vllm, pydantic]
metadata:
hermes:
tags: [Prompt Engineering, Outlines, Structured Generation, JSON Schema, Pydantic, Local Models, Grammar-Based Generation, vLLM, Transformers, Type Safety]
---
# Outlines: Structured Text Generation
## When to Use This Skill
Use Outlines when you need to:
- **Guarantee valid JSON/XML/code** structure during generation
- **Use Pydantic models** for type-safe outputs
- **Support local models** (Transformers, llama.cpp, vLLM)
- **Maximize inference speed** with zero-overhead structured generation
- **Generate against JSON schemas** automatically
- **Control token sampling** at the grammar level
**GitHub Stars**: 8,000+ | **From**: dottxt.ai (formerly .txt)
## Installation
```bash
# Base installation
pip install outlines
# With specific backends
pip install outlines transformers # Hugging Face models
pip install outlines llama-cpp-python # llama.cpp
pip install outlines vllm # vLLM for high-throughput
```
## Quick Start
### Basic Example: Classification
```python
import outlines
from typing import Literal
# Load model
model = outlines.models.transformers("microsoft/Phi-3-mini-4k-instruct")
# Generate with type constraint
prompt = "Sentiment of 'This product is amazing!': "
generator = outlines.generate.choice(model, ["positive", "negative", "neutral"])
sentiment = generator(prompt)
print(sentiment) # "positive" (guaranteed one of these)
```
### With Pydantic Models
```python
from pydantic import BaseModel
import outlines
class User(BaseModel):
name: str
age: int
email: str
model = outlines.models.transformers("microsoft/Phi-3-mini-4k-instruct")
# Generate structured output
prompt = "Extract user: John Doe, 30 years old, john@example.com"
generator = outlines.generate.json(model, User)
user = generator(prompt)
print(user.name) # "John Doe"
print(user.age) # 30
print(user.email) # "john@example.com"
```
## Core Concepts
### 1. Constrained Token Sampling
Outlines uses Finite State Machines (FSM) to constrain token generation at the logit level.
**How it works:**
1. Convert schema (JSON/Pydantic/regex) to context-free grammar (CFG)
2. Transform CFG into Finite State Machine (FSM)
3. Filter invalid tokens at each step during generation
4. Fast-forward when only one valid token exists
**Benefits:**
- **Zero overhead**: Filtering happens at token level
- **Speed improvement**: Fast-forward through deterministic paths
- **Guaranteed validity**: Invalid outputs impossible
```python
import outlines
# Pydantic model -> JSON schema -> CFG -> FSM
class Person(BaseModel):
name: str
age: int
model = outlines.models.transformers("microsoft/Phi-3-mini-4k-instruct")
# Behind the scenes:
# 1. Person -> JSON schema
# 2. JSON schema -> CFG
# 3. CFG -> FSM
# 4. FSM filters tokens during generation
generator = outlines.generate.json(model, Person)
result = generator("Generate person: Alice, 25")
```
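To make the token-level filtering concrete, here is a minimal, self-contained sketch of the idea (a toy illustration with a made-up vocabulary, not Outlines' actual FSM code): any token whose addition would break the pattern gets its logit set to negative infinity before sampling.
```python
import re
import numpy as np

# Toy vocabulary and a pattern the running output must keep satisfying (hypothetical example).
vocab = ["0", "1", "7", "abc", ",", " "]
pattern = re.compile(r"[0-9]*")  # only digits allowed

def mask_logits(logits: np.ndarray, generated: str) -> np.ndarray:
    """Disallow any token that would make the text stop matching the pattern."""
    masked = logits.copy()
    for i, token in enumerate(vocab):
        if pattern.fullmatch(generated + token) is None:
            masked[i] = -np.inf
    return masked

logits = np.random.randn(len(vocab))
print(mask_logits(logits, "4"))  # "abc", "," and " " become -inf; digits stay sampleable
```
Outlines precompiles the equivalent check into an FSM derived from the schema, so at generation time the mask is a table lookup rather than a per-token regex match.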
### 2. Structured Generators
Outlines provides specialized generators for different output types.
#### Choice Generator
```python
# Multiple choice selection
generator = outlines.generate.choice(
model,
["positive", "negative", "neutral"]
)
sentiment = generator("Review: This is great!")
# Result: One of the three choices
```
#### JSON Generator
```python
from pydantic import BaseModel
class Product(BaseModel):
name: str
price: float
in_stock: bool
# Generate valid JSON matching schema
generator = outlines.generate.json(model, Product)
product = generator("Extract: iPhone 15, $999, available")
# Guaranteed valid Product instance
print(type(product)) # <class '__main__.Product'>
```
#### Regex Generator
```python
# Generate text matching regex
generator = outlines.generate.regex(
model,
r"[0-9]{3}-[0-9]{3}-[0-9]{4}" # Phone number pattern
)
phone = generator("Generate phone number:")
# Result: "555-123-4567" (guaranteed to match pattern)
```
#### Integer/Float Generators
```python
# Generate specific numeric types
int_generator = outlines.generate.integer(model)
age = int_generator("Person's age:") # Guaranteed integer
float_generator = outlines.generate.float(model)
price = float_generator("Product price:") # Guaranteed float
```
### 3. Model Backends
Outlines supports multiple local and API-based backends.
#### Transformers (Hugging Face)
```python
import outlines
# Load from Hugging Face
model = outlines.models.transformers(
"microsoft/Phi-3-mini-4k-instruct",
device="cuda" # Or "cpu"
)
# Use with any generator
generator = outlines.generate.json(model, YourModel)
```
#### llama.cpp
```python
# Load GGUF model
model = outlines.models.llamacpp(
"./models/llama-3.1-8b-instruct.Q4_K_M.gguf",
n_gpu_layers=35
)
generator = outlines.generate.json(model, YourModel)
```
#### vLLM (High Throughput)
```python
# For production deployments
model = outlines.models.vllm(
"meta-llama/Llama-3.1-8B-Instruct",
tensor_parallel_size=2 # Multi-GPU
)
generator = outlines.generate.json(model, YourModel)
```
#### OpenAI (Limited Support)
```python
# Basic OpenAI support
model = outlines.models.openai(
"gpt-4o-mini",
api_key="your-api-key"
)
# Note: Some features limited with API models
generator = outlines.generate.json(model, YourModel)
```
### 4. Pydantic Integration
Outlines has first-class Pydantic support with automatic schema translation.
#### Basic Models
```python
from pydantic import BaseModel, Field
class Article(BaseModel):
title: str = Field(description="Article title")
author: str = Field(description="Author name")
word_count: int = Field(description="Number of words", gt=0)
tags: list[str] = Field(description="List of tags")
model = outlines.models.transformers("microsoft/Phi-3-mini-4k-instruct")
generator = outlines.generate.json(model, Article)
article = generator("Generate article about AI")
print(article.title)
print(article.word_count) # Guaranteed > 0
```
#### Nested Models
```python
class Address(BaseModel):
street: str
city: str
country: str
class Person(BaseModel):
name: str
age: int
address: Address # Nested model
generator = outlines.generate.json(model, Person)
person = generator("Generate person in New York")
print(person.address.city) # "New York"
```
#### Enums and Literals
```python
from enum import Enum
from typing import Literal
class Status(str, Enum):
PENDING = "pending"
APPROVED = "approved"
REJECTED = "rejected"
class Application(BaseModel):
applicant: str
status: Status # Must be one of enum values
priority: Literal["low", "medium", "high"] # Must be one of literals
generator = outlines.generate.json(model, Application)
app = generator("Generate application")
print(app.status) # Status.PENDING (or APPROVED/REJECTED)
```
## Common Patterns
### Pattern 1: Data Extraction
```python
from pydantic import BaseModel
import outlines
class CompanyInfo(BaseModel):
name: str
founded_year: int
industry: str
employees: int
model = outlines.models.transformers("microsoft/Phi-3-mini-4k-instruct")
generator = outlines.generate.json(model, CompanyInfo)
text = """
Apple Inc. was founded in 1976 in the technology industry.
The company employs approximately 164,000 people worldwide.
"""
prompt = f"Extract company information:\n{text}\n\nCompany:"
company = generator(prompt)
print(f"Name: {company.name}")
print(f"Founded: {company.founded_year}")
print(f"Industry: {company.industry}")
print(f"Employees: {company.employees}")
```
### Pattern 2: Classification
```python
from typing import Literal
import outlines
model = outlines.models.transformers("microsoft/Phi-3-mini-4k-instruct")
# Binary classification
generator = outlines.generate.choice(model, ["spam", "not_spam"])
result = generator("Email: Buy now! 50% off!")
# Multi-class classification
categories = ["technology", "business", "sports", "entertainment"]
category_gen = outlines.generate.choice(model, categories)
category = category_gen("Article: Apple announces new iPhone...")
# With confidence
class Classification(BaseModel):
label: Literal["positive", "negative", "neutral"]
confidence: float
classifier = outlines.generate.json(model, Classification)
result = classifier("Review: This product is okay, nothing special")
```
### Pattern 3: Structured Forms
```python
class UserProfile(BaseModel):
full_name: str
age: int
email: str
phone: str
country: str
interests: list[str]
model = outlines.models.transformers("microsoft/Phi-3-mini-4k-instruct")
generator = outlines.generate.json(model, UserProfile)
prompt = """
Extract user profile from:
Name: Alice Johnson
Age: 28
Email: alice@example.com
Phone: 555-0123
Country: USA
Interests: hiking, photography, cooking
"""
profile = generator(prompt)
print(profile.full_name)
print(profile.interests) # ["hiking", "photography", "cooking"]
```
### Pattern 4: Multi-Entity Extraction
```python
class Entity(BaseModel):
name: str
type: Literal["PERSON", "ORGANIZATION", "LOCATION"]
class DocumentEntities(BaseModel):
entities: list[Entity]
model = outlines.models.transformers("microsoft/Phi-3-mini-4k-instruct")
generator = outlines.generate.json(model, DocumentEntities)
text = "Tim Cook met with Satya Nadella at Microsoft headquarters in Redmond."
prompt = f"Extract entities from: {text}"
result = generator(prompt)
for entity in result.entities:
print(f"{entity.name} ({entity.type})")
```
### Pattern 5: Code Generation
```python
class PythonFunction(BaseModel):
function_name: str
parameters: list[str]
docstring: str
body: str
model = outlines.models.transformers("microsoft/Phi-3-mini-4k-instruct")
generator = outlines.generate.json(model, PythonFunction)
prompt = "Generate a Python function to calculate factorial"
func = generator(prompt)
print(f"def {func.function_name}({', '.join(func.parameters)}):")
print(f' """{func.docstring}"""')
print(f" {func.body}")
```
### Pattern 6: Batch Processing
```python
def batch_extract(texts: list[str], schema: type[BaseModel]):
"""Extract structured data from multiple texts."""
model = outlines.models.transformers("microsoft/Phi-3-mini-4k-instruct")
generator = outlines.generate.json(model, schema)
results = []
for text in texts:
result = generator(f"Extract from: {text}")
results.append(result)
return results
class Person(BaseModel):
name: str
age: int
texts = [
"John is 30 years old",
"Alice is 25 years old",
"Bob is 40 years old"
]
people = batch_extract(texts, Person)
for person in people:
print(f"{person.name}: {person.age}")
```
## Backend Configuration
### Transformers
```python
import outlines
# Basic usage
model = outlines.models.transformers("microsoft/Phi-3-mini-4k-instruct")
# GPU configuration
model = outlines.models.transformers(
"microsoft/Phi-3-mini-4k-instruct",
device="cuda",
model_kwargs={"torch_dtype": "float16"}
)
# Popular models
model = outlines.models.transformers("meta-llama/Llama-3.1-8B-Instruct")
model = outlines.models.transformers("mistralai/Mistral-7B-Instruct-v0.3")
model = outlines.models.transformers("Qwen/Qwen2.5-7B-Instruct")
```
### llama.cpp
```python
# Load GGUF model
model = outlines.models.llamacpp(
"./models/llama-3.1-8b.Q4_K_M.gguf",
n_ctx=4096, # Context window
n_gpu_layers=35, # GPU layers
n_threads=8 # CPU threads
)
# Full GPU offload
model = outlines.models.llamacpp(
"./models/model.gguf",
n_gpu_layers=-1 # All layers on GPU
)
```
### vLLM (Production)
```python
# Single GPU
model = outlines.models.vllm("meta-llama/Llama-3.1-8B-Instruct")
# Multi-GPU
model = outlines.models.vllm(
"meta-llama/Llama-3.1-70B-Instruct",
tensor_parallel_size=4 # 4 GPUs
)
# With quantization
model = outlines.models.vllm(
"meta-llama/Llama-3.1-8B-Instruct",
quantization="awq" # Or "gptq"
)
```
## Best Practices
### 1. Use Specific Types
```python
# ✅ Good: Specific types
class Product(BaseModel):
name: str
price: float # Not str
quantity: int # Not str
in_stock: bool # Not str
# ❌ Bad: Everything as string
class Product(BaseModel):
name: str
price: str # Should be float
quantity: str # Should be int
```
### 2. Add Constraints
```python
from pydantic import Field
# ✅ Good: With constraints
class User(BaseModel):
name: str = Field(min_length=1, max_length=100)
age: int = Field(ge=0, le=120)
email: str = Field(pattern=r"^[\w\.-]+@[\w\.-]+\.\w+$")
# ❌ Bad: No constraints
class User(BaseModel):
name: str
age: int
email: str
```
### 3. Use Enums for Categories
```python
# ✅ Good: Enum for fixed set
class Priority(str, Enum):
LOW = "low"
MEDIUM = "medium"
HIGH = "high"
class Task(BaseModel):
title: str
priority: Priority
# ❌ Bad: Free-form string
class Task(BaseModel):
title: str
priority: str # Can be anything
```
### 4. Provide Context in Prompts
```python
# ✅ Good: Clear context
prompt = """
Extract product information from the following text.
Text: iPhone 15 Pro costs $999 and is currently in stock.
Product:
"""
# ❌ Bad: Minimal context
prompt = "iPhone 15 Pro costs $999 and is currently in stock."
```
### 5. Handle Optional Fields
```python
from typing import Optional
# ✅ Good: Optional fields for incomplete data
class Article(BaseModel):
title: str # Required
author: Optional[str] = None # Optional
date: Optional[str] = None # Optional
tags: list[str] = [] # Default empty list
# Can succeed even if author/date missing
```
## Comparison to Alternatives
| Feature | Outlines | Instructor | Guidance | LMQL |
|---------|----------|------------|----------|------|
| Pydantic Support | ✅ Native | ✅ Native | ❌ No | ❌ No |
| JSON Schema | ✅ Yes | ✅ Yes | ⚠️ Limited | ✅ Yes |
| Regex Constraints | ✅ Yes | ❌ No | ✅ Yes | ✅ Yes |
| Local Models | ✅ Full | ⚠️ Limited | ✅ Full | ✅ Full |
| API Models | ⚠️ Limited | ✅ Full | ✅ Full | ✅ Full |
| Zero Overhead | ✅ Yes | ❌ No | ⚠️ Partial | ✅ Yes |
| Automatic Retrying | ❌ No | ✅ Yes | ❌ No | ❌ No |
| Learning Curve | Low | Low | Low | High |
**When to choose Outlines:**
- Using local models (Transformers, llama.cpp, vLLM)
- Need maximum inference speed
- Want Pydantic model support
- Require zero-overhead structured generation
- Control token sampling process
**When to choose alternatives:**
- Instructor: Need API models with automatic retrying
- Guidance: Need token healing and complex workflows
- LMQL: Prefer declarative query syntax
## Performance Characteristics
**Speed:**
- **Zero overhead**: Structured generation as fast as unconstrained
- **Fast-forward optimization**: Skips deterministic tokens
- **1.2-2x faster** than post-generation validation approaches (see the rough timing sketch at the end of this section)
**Memory:**
- FSM compiled once per schema (cached)
- Minimal runtime overhead
- Efficient with vLLM for high throughput
**Accuracy:**
- **100% valid outputs** (guaranteed by FSM)
- No retry loops needed
- Deterministic token filtering
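A rough way to sanity-check the overhead claim on your own hardware (an illustrative script, not a rigorous benchmark; the model name and prompt are placeholders, and the first structured call also pays a one-time FSM compilation cost, so we warm it up first):
```python
import time
import outlines
from pydantic import BaseModel

class User(BaseModel):
    name: str
    age: int

model = outlines.models.transformers("microsoft/Phi-3-mini-4k-instruct")
structured = outlines.generate.json(model, User)
unconstrained = outlines.generate.text(model)

prompt = "Extract user: John Doe, 30 years old"
structured(prompt)  # warm-up: compiles and caches the FSM

for name, gen in [("structured", structured), ("unconstrained", unconstrained)]:
    start = time.perf_counter()
    gen(prompt, max_tokens=64)
    print(f"{name}: {time.perf_counter() - start:.2f}s")
```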
## Resources
- **Documentation**: https://outlines-dev.github.io/outlines
- **GitHub**: https://github.com/outlines-dev/outlines (8k+ stars)
- **Discord**: https://discord.gg/R9DSu34mGd
- **Blog**: https://blog.dottxt.co
## See Also
- `references/json_generation.md` - Comprehensive JSON and Pydantic patterns
- `references/backends.md` - Backend-specific configuration
- `references/examples.md` - Production-ready examples

View file

@@ -0,0 +1,615 @@
# Backend Configuration Guide
Complete guide to configuring Outlines with different model backends.
## Table of Contents
- Local Models (Transformers, llama.cpp, vLLM)
- API Models (OpenAI)
- Performance Comparison
- Configuration Examples
- Production Deployment
## Transformers (Hugging Face)
### Basic Setup
```python
import outlines
# Load model from Hugging Face
model = outlines.models.transformers("microsoft/Phi-3-mini-4k-instruct")
# Use with generator
generator = outlines.generate.json(model, YourModel)
result = generator("Your prompt")
```
### GPU Configuration
```python
# Use CUDA GPU
model = outlines.models.transformers(
"microsoft/Phi-3-mini-4k-instruct",
device="cuda"
)
# Use specific GPU
model = outlines.models.transformers(
"microsoft/Phi-3-mini-4k-instruct",
device="cuda:0" # GPU 0
)
# Use CPU
model = outlines.models.transformers(
"microsoft/Phi-3-mini-4k-instruct",
device="cpu"
)
# Use Apple Silicon MPS
model = outlines.models.transformers(
"microsoft/Phi-3-mini-4k-instruct",
device="mps"
)
```
### Advanced Configuration
```python
# FP16 for faster inference
model = outlines.models.transformers(
"microsoft/Phi-3-mini-4k-instruct",
device="cuda",
model_kwargs={
"torch_dtype": "float16"
}
)
# 8-bit quantization (less memory)
model = outlines.models.transformers(
"microsoft/Phi-3-mini-4k-instruct",
device="cuda",
model_kwargs={
"load_in_8bit": True,
"device_map": "auto"
}
)
# 4-bit quantization (even less memory)
model = outlines.models.transformers(
"meta-llama/Llama-3.1-70B-Instruct",
device="cuda",
model_kwargs={
"load_in_4bit": True,
"device_map": "auto",
"bnb_4bit_compute_dtype": "float16"
}
)
# Multi-GPU
model = outlines.models.transformers(
"meta-llama/Llama-3.1-70B-Instruct",
device="cuda",
model_kwargs={
"device_map": "auto", # Automatic GPU distribution
"max_memory": {0: "40GB", 1: "40GB"} # Per-GPU limits
}
)
```
### Popular Models
```python
# Phi-4 (Microsoft)
model = outlines.models.transformers("microsoft/Phi-4-mini-instruct")
model = outlines.models.transformers("microsoft/Phi-3-medium-4k-instruct")
# Llama 3.1 (Meta)
model = outlines.models.transformers("meta-llama/Llama-3.1-8B-Instruct")
model = outlines.models.transformers("meta-llama/Llama-3.1-70B-Instruct")
model = outlines.models.transformers("meta-llama/Llama-3.1-405B-Instruct")
# Mistral (Mistral AI)
model = outlines.models.transformers("mistralai/Mistral-7B-Instruct-v0.3")
model = outlines.models.transformers("mistralai/Mixtral-8x7B-Instruct-v0.1")
model = outlines.models.transformers("mistralai/Mixtral-8x22B-Instruct-v0.1")
# Qwen (Alibaba)
model = outlines.models.transformers("Qwen/Qwen2.5-7B-Instruct")
model = outlines.models.transformers("Qwen/Qwen2.5-14B-Instruct")
model = outlines.models.transformers("Qwen/Qwen2.5-72B-Instruct")
# Gemma (Google)
model = outlines.models.transformers("google/gemma-2-9b-it")
model = outlines.models.transformers("google/gemma-2-27b-it")
# Llava (Vision)
model = outlines.models.transformers("llava-hf/llava-v1.6-mistral-7b-hf")
```
### Custom Model Loading
```python
from transformers import AutoTokenizer, AutoModelForCausalLM
import outlines
# Load model manually
tokenizer = AutoTokenizer.from_pretrained("your-model")
model_hf = AutoModelForCausalLM.from_pretrained(
"your-model",
device_map="auto",
torch_dtype="float16"
)
# Use with Outlines
model = outlines.models.transformers(
model=model_hf,
tokenizer=tokenizer
)
```
## llama.cpp
### Basic Setup
```python
import outlines
# Load GGUF model
model = outlines.models.llamacpp(
"./models/llama-3.1-8b-instruct.Q4_K_M.gguf",
n_ctx=4096 # Context window
)
# Use with generator
generator = outlines.generate.json(model, YourModel)
```
### GPU Configuration
```python
# CPU only
model = outlines.models.llamacpp(
"./models/model.gguf",
n_ctx=4096,
n_threads=8 # Use 8 CPU threads
)
# GPU offload (partial)
model = outlines.models.llamacpp(
"./models/model.gguf",
n_ctx=4096,
n_gpu_layers=35, # Offload 35 layers to GPU
n_threads=4 # CPU threads for remaining layers
)
# Full GPU offload
model = outlines.models.llamacpp(
"./models/model.gguf",
n_ctx=8192,
n_gpu_layers=-1 # All layers on GPU
)
```
### Advanced Configuration
```python
model = outlines.models.llamacpp(
"./models/llama-3.1-8b.Q4_K_M.gguf",
n_ctx=8192, # Context window (tokens)
n_gpu_layers=35, # GPU layers
n_threads=8, # CPU threads
n_batch=512, # Batch size for prompt processing
use_mmap=True, # Memory-map model file (faster loading)
use_mlock=False, # Lock model in RAM (prevents swapping)
seed=42, # Random seed for reproducibility
verbose=False # Suppress verbose output
)
```
### Quantization Formats
```python
# Q4_K_M (4-bit, recommended for most cases)
# - Size: ~4.5GB for 7B model
# - Quality: Good
# - Speed: Fast
model = outlines.models.llamacpp("./models/model.Q4_K_M.gguf")
# Q5_K_M (5-bit, better quality)
# - Size: ~5.5GB for 7B model
# - Quality: Very good
# - Speed: Slightly slower than Q4
model = outlines.models.llamacpp("./models/model.Q5_K_M.gguf")
# Q6_K (6-bit, high quality)
# - Size: ~6.5GB for 7B model
# - Quality: Excellent
# - Speed: Slower than Q5
model = outlines.models.llamacpp("./models/model.Q6_K.gguf")
# Q8_0 (8-bit, near-original quality)
# - Size: ~8GB for 7B model
# - Quality: Near FP16
# - Speed: Slower than Q6
model = outlines.models.llamacpp("./models/model.Q8_0.gguf")
# F16 (16-bit float, original quality)
# - Size: ~14GB for 7B model
# - Quality: Original
# - Speed: Slowest
model = outlines.models.llamacpp("./models/model.F16.gguf")
```
### Popular GGUF Models
```python
# Llama 3.1
model = outlines.models.llamacpp("llama-3.1-8b-instruct.Q4_K_M.gguf")
model = outlines.models.llamacpp("llama-3.1-70b-instruct.Q4_K_M.gguf")
# Mistral
model = outlines.models.llamacpp("mistral-7b-instruct-v0.3.Q4_K_M.gguf")
# Phi-4
model = outlines.models.llamacpp("phi-4-mini-instruct.Q4_K_M.gguf")
# Qwen
model = outlines.models.llamacpp("qwen2.5-7b-instruct.Q4_K_M.gguf")
```
### Apple Silicon Optimization
```python
# Optimized for M1/M2/M3 Macs
model = outlines.models.llamacpp(
"./models/llama-3.1-8b.Q4_K_M.gguf",
n_ctx=4096,
n_gpu_layers=-1, # Use Metal GPU acceleration
use_mmap=True, # Efficient memory mapping
n_threads=8 # Use performance cores
)
```
## vLLM (Production)
### Basic Setup
```python
import outlines
# Load model with vLLM
model = outlines.models.vllm("meta-llama/Llama-3.1-8B-Instruct")
# Use with generator
generator = outlines.generate.json(model, YourModel)
```
### Single GPU
```python
model = outlines.models.vllm(
"meta-llama/Llama-3.1-8B-Instruct",
gpu_memory_utilization=0.9, # Use 90% of GPU memory
max_model_len=4096 # Max sequence length
)
```
### Multi-GPU
```python
# Tensor parallelism (split model across GPUs)
model = outlines.models.vllm(
"meta-llama/Llama-3.1-70B-Instruct",
tensor_parallel_size=4, # Use 4 GPUs
gpu_memory_utilization=0.9
)
# Pipeline parallelism (rare, for very large models)
model = outlines.models.vllm(
"meta-llama/Llama-3.1-405B-Instruct",
pipeline_parallel_size=8, # 8-GPU pipeline
tensor_parallel_size=4 # 4-GPU tensor split
# Total: 32 GPUs
)
```
### Quantization
```python
# AWQ quantization (4-bit)
model = outlines.models.vllm(
"meta-llama/Llama-3.1-8B-Instruct",
quantization="awq",
dtype="float16"
)
# GPTQ quantization (4-bit)
model = outlines.models.vllm(
"meta-llama/Llama-3.1-8B-Instruct",
quantization="gptq"
)
# SqueezeLLM quantization
model = outlines.models.vllm(
"meta-llama/Llama-3.1-8B-Instruct",
quantization="squeezellm"
)
```
### Advanced Configuration
```python
model = outlines.models.vllm(
"meta-llama/Llama-3.1-8B-Instruct",
tensor_parallel_size=1,
gpu_memory_utilization=0.9,
max_model_len=8192,
max_num_seqs=256, # Max concurrent sequences
max_num_batched_tokens=8192, # Max tokens per batch
dtype="float16",
trust_remote_code=True,
enforce_eager=False, # Use CUDA graphs (faster)
swap_space=4 # CPU swap space (GB)
)
```
### Batch Processing
```python
# vLLM optimized for high-throughput batch processing
model = outlines.models.vllm(
"meta-llama/Llama-3.1-8B-Instruct",
max_num_seqs=128 # Process 128 sequences in parallel
)
generator = outlines.generate.json(model, YourModel)
# Process many prompts efficiently
prompts = ["prompt1", "prompt2", ..., "prompt100"]
results = generator(prompts)
# Passing the full list in one call lets vLLM batch and optimize across prompts
```
## OpenAI (Limited Support)
### Basic Setup
```python
import outlines
# Basic OpenAI support
model = outlines.models.openai("gpt-4o-mini", api_key="your-api-key")
# Use with generator
generator = outlines.generate.json(model, YourModel)
result = generator("Your prompt")
```
### Configuration
```python
model = outlines.models.openai(
"gpt-4o-mini",
api_key="your-api-key", # Or set OPENAI_API_KEY env var
max_tokens=2048,
temperature=0.7
)
```
### Available Models
```python
# GPT-4o (latest)
model = outlines.models.openai("gpt-4o")
# GPT-4o Mini (cost-effective)
model = outlines.models.openai("gpt-4o-mini")
# GPT-4 Turbo
model = outlines.models.openai("gpt-4-turbo")
# GPT-3.5 Turbo
model = outlines.models.openai("gpt-3.5-turbo")
```
**Note**: OpenAI support is limited compared to local models. Some advanced features may not work.
## Backend Comparison
### Feature Matrix
| Feature | Transformers | llama.cpp | vLLM | OpenAI |
|---------|-------------|-----------|------|--------|
| Structured Generation | ✅ Full | ✅ Full | ✅ Full | ⚠️ Limited |
| FSM Optimization | ✅ Yes | ✅ Yes | ✅ Yes | ❌ No |
| GPU Support | ✅ Yes | ✅ Yes | ✅ Yes | N/A |
| Multi-GPU | ✅ Yes | ✅ Yes | ✅ Yes | N/A |
| Quantization | ✅ Yes | ✅ Yes | ✅ Yes | N/A |
| High Throughput | ⚠️ Medium | ⚠️ Medium | ✅ Excellent | ⚠️ API-limited |
| Setup Difficulty | Easy | Medium | Medium | Easy |
| Cost | Hardware | Hardware | Hardware | API usage |
### Performance Characteristics
**Transformers:**
- **Latency**: 50-200ms (single request, GPU)
- **Throughput**: 10-50 tokens/sec (depends on hardware)
- **Memory**: 2-4GB per 1B parameters (FP16)
- **Best for**: Development, small-scale deployment, flexibility
**llama.cpp:**
- **Latency**: 30-150ms (single request)
- **Throughput**: 20-150 tokens/sec (depends on quantization)
- **Memory**: 0.5-2GB per 1B parameters (Q4-Q8)
- **Best for**: CPU inference, Apple Silicon, edge deployment, low memory
**vLLM:**
- **Latency**: 30-100ms (single request)
- **Throughput**: 100-1000+ tokens/sec (batch processing)
- **Memory**: 2-4GB per 1B parameters (FP16)
- **Best for**: Production, high-throughput, batch processing, serving
**OpenAI:**
- **Latency**: 200-500ms (API call)
- **Throughput**: API rate limits
- **Memory**: N/A (cloud-based)
- **Best for**: Quick prototyping, no infrastructure
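To compare backends on your own hardware, a crude tokens-per-second probe is usually enough. The helper below is an illustrative sketch: it assumes a `generator` built from one of the models loaded above and that generation runs to `max_tokens`.
```python
import time

def tokens_per_second(generator, prompt: str, max_tokens: int = 128) -> float:
    """Crude throughput probe: requested tokens divided by wall-clock time."""
    start = time.perf_counter()
    generator(prompt, max_tokens=max_tokens)
    elapsed = time.perf_counter() - start
    return max_tokens / elapsed  # upper bound if generation stops early

# Example (any backend from this guide):
# generator = outlines.generate.text(model)
# print(tokens_per_second(generator, "Write a short story about a robot."))
```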
### Memory Requirements
**7B Model:**
- FP16: ~14GB
- 8-bit: ~7GB
- 4-bit: ~4GB
- Q4_K_M (GGUF): ~4.5GB
**13B Model:**
- FP16: ~26GB
- 8-bit: ~13GB
- 4-bit: ~7GB
- Q4_K_M (GGUF): ~8GB
**70B Model:**
- FP16: ~140GB (multi-GPU)
- 8-bit: ~70GB (multi-GPU)
- 4-bit: ~35GB (single A100/H100)
- Q4_K_M (GGUF): ~40GB
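These figures follow from a simple rule of thumb: weight memory ≈ parameter count × bits per parameter / 8, plus extra for the KV cache and activations. A small helper to reproduce the estimates (an illustrative sketch; real usage varies with context length and batch size):
```python
def weight_memory_gb(params_billion: float, bits_per_param: float) -> float:
    """Approximate GPU memory for the weights alone (excludes KV cache and activations)."""
    return params_billion * bits_per_param / 8

print(weight_memory_gb(7, 16))    # ~14 GB  -> FP16 7B
print(weight_memory_gb(7, 4.85))  # ~4.2 GB -> roughly Q4_K_M (4-bit plus scales)
print(weight_memory_gb(70, 4))    # ~35 GB  -> 4-bit 70B
```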
## Performance Tuning
### Transformers Optimization
```python
# Use FP16
model = outlines.models.transformers(
"meta-llama/Llama-3.1-8B-Instruct",
device="cuda",
model_kwargs={"torch_dtype": "float16"}
)
# Use flash attention (2-4x faster)
model = outlines.models.transformers(
"meta-llama/Llama-3.1-8B-Instruct",
device="cuda",
model_kwargs={
"torch_dtype": "float16",
"use_flash_attention_2": True
}
)
# Use 8-bit quantization (2x less memory)
model = outlines.models.transformers(
"meta-llama/Llama-3.1-8B-Instruct",
device="cuda",
model_kwargs={
"load_in_8bit": True,
"device_map": "auto"
}
)
```
### llama.cpp Optimization
```python
# Maximize GPU usage
model = outlines.models.llamacpp(
"./models/model.Q4_K_M.gguf",
n_gpu_layers=-1, # All layers on GPU
n_ctx=8192,
n_batch=512 # Larger batch = faster
)
# Optimize for CPU (Apple Silicon)
model = outlines.models.llamacpp(
"./models/model.Q4_K_M.gguf",
n_ctx=4096,
n_threads=8, # Use all performance cores
use_mmap=True
)
```
### vLLM Optimization
```python
# High throughput
model = outlines.models.vllm(
"meta-llama/Llama-3.1-8B-Instruct",
gpu_memory_utilization=0.95, # Use 95% of GPU
max_num_seqs=256, # High concurrency
enforce_eager=False # Use CUDA graphs
)
# Multi-GPU
model = outlines.models.vllm(
"meta-llama/Llama-3.1-70B-Instruct",
tensor_parallel_size=4, # 4 GPUs
gpu_memory_utilization=0.9
)
```
## Production Deployment
### Docker with vLLM
```dockerfile
FROM vllm/vllm-openai:latest
# Install outlines
RUN pip install outlines
# Copy your code
COPY app.py /app/
# Run
CMD ["python", "/app/app.py"]
```
### Environment Variables
```bash
# Transformers cache
export HF_HOME="/path/to/cache"
export TRANSFORMERS_CACHE="/path/to/cache"
# GPU selection
export CUDA_VISIBLE_DEVICES=0,1,2,3
# OpenAI API key
export OPENAI_API_KEY="sk-..."
# Disable tokenizers parallelism warning
export TOKENIZERS_PARALLELISM=false
```
### Model Serving
```python
# Simple HTTP server with vLLM
import outlines
from fastapi import FastAPI
from pydantic import BaseModel
app = FastAPI()
# Load model once at startup
model = outlines.models.vllm("meta-llama/Llama-3.1-8B-Instruct")
class User(BaseModel):
name: str
age: int
email: str
generator = outlines.generate.json(model, User)
@app.post("/extract")
def extract(text: str):
result = generator(f"Extract user from: {text}")
return result.model_dump()
```
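To exercise the service locally you could start it with `uvicorn app:app` (assuming the snippet above is saved as `app.py`) and call it from Python. The client below is a hypothetical example that assumes the default port 8000; note that FastAPI treats the plain `text: str` argument as a query parameter.
```python
# Hypothetical client for the service above (assumes it listens on localhost:8000).
import requests

resp = requests.post(
    "http://localhost:8000/extract",
    params={"text": "Jane Doe, 30, jane@example.com"},
)
resp.raise_for_status()
print(resp.json())  # e.g. {"name": "Jane Doe", "age": 30, "email": "jane@example.com"}
```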
## Resources
- **Transformers**: https://huggingface.co/docs/transformers
- **llama.cpp**: https://github.com/ggerganov/llama.cpp
- **vLLM**: https://docs.vllm.ai
- **Outlines**: https://github.com/outlines-dev/outlines

View file

@@ -0,0 +1,773 @@
# Production-Ready Examples
Real-world examples of using Outlines for structured generation in production systems.
## Table of Contents
- Data Extraction
- Classification Systems
- Form Processing
- Multi-Entity Extraction
- Code Generation
- Batch Processing
- Production Patterns
## Data Extraction
### Basic Information Extraction
```python
from pydantic import BaseModel, Field
import outlines
class PersonInfo(BaseModel):
name: str = Field(description="Full name")
age: int = Field(ge=0, le=120)
occupation: str
email: str = Field(pattern=r"^[\w\.-]+@[\w\.-]+\.\w+$")
location: str
model = outlines.models.transformers("microsoft/Phi-3-mini-4k-instruct")
generator = outlines.generate.json(model, PersonInfo)
text = """
Dr. Sarah Johnson is a 42-year-old research scientist at MIT.
She can be reached at sarah.j@mit.edu and currently lives in Cambridge, MA.
"""
prompt = f"Extract person information from:\n{text}\n\nPerson:"
person = generator(prompt)
print(f"Name: {person.name}")
print(f"Age: {person.age}")
print(f"Occupation: {person.occupation}")
print(f"Email: {person.email}")
print(f"Location: {person.location}")
```
### Company Information
```python
from typing import Optional

class CompanyInfo(BaseModel):
name: str
founded_year: int = Field(ge=1800, le=2025)
industry: str
headquarters: str
employees: int = Field(gt=0)
revenue: Optional[str] = None
model = outlines.models.transformers("meta-llama/Llama-3.1-8B-Instruct")
generator = outlines.generate.json(model, CompanyInfo)
text = """
Tesla, Inc. was founded in 2003 and operates primarily in the automotive
and energy industries. The company is headquartered in Austin, Texas,
and employs approximately 140,000 people worldwide.
"""
company = generator(f"Extract company information:\n{text}\n\nCompany:")
print(f"Company: {company.name}")
print(f"Founded: {company.founded_year}")
print(f"Industry: {company.industry}")
print(f"HQ: {company.headquarters}")
print(f"Employees: {company.employees:,}")
```
### Product Specifications
```python
class ProductSpec(BaseModel):
name: str
brand: str
price: float = Field(gt=0)
dimensions: str
weight: str
features: list[str]
rating: Optional[float] = Field(None, ge=0, le=5)
generator = outlines.generate.json(model, ProductSpec)
text = """
The Apple iPhone 15 Pro is priced at $999. It measures 146.6 x 70.6 x 8.25 mm
and weighs 187 grams. Key features include the A17 Pro chip, titanium design,
action button, and USB-C port. It has an average customer rating of 4.5 stars.
"""
product = generator(f"Extract product specifications:\n{text}\n\nProduct:")
print(f"Product: {product.brand} {product.name}")
print(f"Price: ${product.price}")
print(f"Features: {', '.join(product.features)}")
```
## Classification Systems
### Sentiment Analysis
```python
from typing import Literal
from enum import Enum
class Sentiment(str, Enum):
VERY_POSITIVE = "very_positive"
POSITIVE = "positive"
NEUTRAL = "neutral"
NEGATIVE = "negative"
VERY_NEGATIVE = "very_negative"
class SentimentAnalysis(BaseModel):
text: str
sentiment: Sentiment
confidence: float = Field(ge=0.0, le=1.0)
aspects: list[str] # What aspects were mentioned
reasoning: str
model = outlines.models.transformers("microsoft/Phi-3-mini-4k-instruct")
generator = outlines.generate.json(model, SentimentAnalysis)
review = """
This product completely exceeded my expectations! The build quality is
outstanding, and customer service was incredibly helpful. My only minor
complaint is the packaging could be better.
"""
result = generator(f"Analyze sentiment:\n{review}\n\nAnalysis:")
print(f"Sentiment: {result.sentiment.value}")
print(f"Confidence: {result.confidence:.2%}")
print(f"Aspects: {', '.join(result.aspects)}")
print(f"Reasoning: {result.reasoning}")
```
### Content Classification
```python
class Category(str, Enum):
TECHNOLOGY = "technology"
BUSINESS = "business"
SCIENCE = "science"
POLITICS = "politics"
ENTERTAINMENT = "entertainment"
SPORTS = "sports"
HEALTH = "health"
class ArticleClassification(BaseModel):
primary_category: Category
secondary_categories: list[Category]
keywords: list[str] = Field(min_items=3, max_items=10)
target_audience: Literal["general", "expert", "beginner"]
reading_level: Literal["elementary", "intermediate", "advanced"]
generator = outlines.generate.json(model, ArticleClassification)
article = """
Apple announced groundbreaking advancements in its AI capabilities with the
release of iOS 18. The new features leverage machine learning to significantly
improve battery life and overall device performance. Industry analysts predict
this will strengthen Apple's position in the competitive smartphone market.
"""
classification = generator(f"Classify article:\n{article}\n\nClassification:")
print(f"Primary: {classification.primary_category.value}")
print(f"Secondary: {[c.value for c in classification.secondary_categories]}")
print(f"Keywords: {classification.keywords}")
print(f"Audience: {classification.target_audience}")
```
### Intent Recognition
```python
class Intent(str, Enum):
QUESTION = "question"
COMPLAINT = "complaint"
REQUEST = "request"
FEEDBACK = "feedback"
CANCEL = "cancel"
UPGRADE = "upgrade"
class UserMessage(BaseModel):
original_message: str
intent: Intent
urgency: Literal["low", "medium", "high", "critical"]
department: Literal["support", "sales", "billing", "technical"]
sentiment: Literal["positive", "neutral", "negative"]
action_required: bool
summary: str
generator = outlines.generate.json(model, UserMessage)
message = """
I've been charged twice for my subscription this month! This is the third
time this has happened. I need someone to fix this immediately and refund
the extra charge. Very disappointed with this service.
"""
result = generator(f"Analyze message:\n{message}\n\nAnalysis:")
print(f"Intent: {result.intent.value}")
print(f"Urgency: {result.urgency}")
print(f"Route to: {result.department}")
print(f"Action required: {result.action_required}")
print(f"Summary: {result.summary}")
```
## Form Processing
### Job Application
```python
class Education(BaseModel):
degree: str
field: str
institution: str
year: int
class Experience(BaseModel):
title: str
company: str
duration: str
responsibilities: list[str]
class JobApplication(BaseModel):
full_name: str
email: str
phone: str
education: list[Education]
experience: list[Experience]
skills: list[str]
availability: str
model = outlines.models.transformers("meta-llama/Llama-3.1-8B-Instruct")
generator = outlines.generate.json(model, JobApplication)
resume_text = """
John Smith
Email: john.smith@email.com | Phone: 555-0123
EDUCATION
- BS in Computer Science, MIT, 2018
- MS in Artificial Intelligence, Stanford, 2020
EXPERIENCE
Software Engineer, Google (2020-2023)
- Developed ML pipelines for search ranking
- Led team of 5 engineers
- Improved search quality by 15%
SKILLS: Python, Machine Learning, TensorFlow, System Design
AVAILABILITY: Immediate
"""
application = generator(f"Extract job application:\n{resume_text}\n\nApplication:")
print(f"Applicant: {application.full_name}")
print(f"Email: {application.email}")
print(f"Education: {len(application.education)} degrees")
for edu in application.education:
print(f" - {edu.degree} in {edu.field}, {edu.institution} ({edu.year})")
print(f"Experience: {len(application.experience)} positions")
```
### Invoice Processing
```python
class InvoiceItem(BaseModel):
description: str
quantity: int = Field(gt=0)
unit_price: float = Field(gt=0)
total: float = Field(gt=0)
class Invoice(BaseModel):
invoice_number: str
date: str = Field(pattern=r"\d{4}-\d{2}-\d{2}")
vendor: str
customer: str
items: list[InvoiceItem]
subtotal: float = Field(gt=0)
tax: float = Field(ge=0)
total: float = Field(gt=0)
generator = outlines.generate.json(model, Invoice)
invoice_text = """
INVOICE #INV-2024-001
Date: 2024-01-15
From: Acme Corp
To: Smith & Co
Items:
- Widget A: 10 units @ $50.00 = $500.00
- Widget B: 5 units @ $75.00 = $375.00
- Service Fee: 1 @ $100.00 = $100.00
Subtotal: $975.00
Tax (8%): $78.00
TOTAL: $1,053.00
"""
invoice = generator(f"Extract invoice:\n{invoice_text}\n\nInvoice:")
print(f"Invoice: {invoice.invoice_number}")
print(f"From: {invoice.vendor} → To: {invoice.customer}")
print(f"Items: {len(invoice.items)}")
for item in invoice.items:
print(f" - {item.description}: {item.quantity} × ${item.unit_price} = ${item.total}")
print(f"Total: ${invoice.total}")
```
### Survey Responses
```python
class SurveyResponse(BaseModel):
respondent_id: str
completion_date: str
satisfaction: Literal[1, 2, 3, 4, 5]
would_recommend: bool
favorite_features: list[str]
improvement_areas: list[str]
additional_comments: Optional[str] = None
generator = outlines.generate.json(model, SurveyResponse)
survey_text = """
Survey ID: RESP-12345
Completed: 2024-01-20
How satisfied are you with our product? 4 out of 5
Would you recommend to a friend? Yes
What features do you like most?
- Fast performance
- Easy to use
- Great customer support
What could we improve?
- Better documentation
- More integrations
Additional feedback: Overall great product, keep up the good work!
"""
response = generator(f"Extract survey response:\n{survey_text}\n\nResponse:")
print(f"Respondent: {response.respondent_id}")
print(f"Satisfaction: {response.satisfaction}/5")
print(f"Would recommend: {response.would_recommend}")
print(f"Favorite features: {response.favorite_features}")
print(f"Improvement areas: {response.improvement_areas}")
```
## Multi-Entity Extraction
### News Article Entities
```python
class Person(BaseModel):
name: str
role: Optional[str] = None
affiliation: Optional[str] = None
class Organization(BaseModel):
name: str
type: Optional[str] = None
class Location(BaseModel):
name: str
type: Literal["city", "state", "country", "region"]
class Event(BaseModel):
name: str
date: Optional[str] = None
location: Optional[str] = None
class ArticleEntities(BaseModel):
people: list[Person]
organizations: list[Organization]
locations: list[Location]
events: list[Event]
dates: list[str]
model = outlines.models.transformers("meta-llama/Llama-3.1-8B-Instruct")
generator = outlines.generate.json(model, ArticleEntities)
article = """
Apple CEO Tim Cook met with Microsoft CEO Satya Nadella at Microsoft
headquarters in Redmond, Washington on September 15, 2024, to discuss
potential collaboration opportunities. The meeting was attended by executives
from both companies and focused on AI integration strategies. Apple's
Cupertino offices will host a follow-up meeting on October 20, 2024.
"""
entities = generator(f"Extract all entities:\n{article}\n\nEntities:")
print("People:")
for person in entities.people:
print(f" - {person.name} ({person.role}) @ {person.affiliation}")
print("\nOrganizations:")
for org in entities.organizations:
print(f" - {org.name} ({org.type})")
print("\nLocations:")
for loc in entities.locations:
print(f" - {loc.name} ({loc.type})")
print("\nEvents:")
for event in entities.events:
print(f" - {event.name} on {event.date}")
```
### Document Metadata
```python
class Author(BaseModel):
name: str
email: Optional[str] = None
affiliation: Optional[str] = None
class Reference(BaseModel):
title: str
authors: list[str]
year: int
source: str
class DocumentMetadata(BaseModel):
title: str
authors: list[Author]
abstract: str
keywords: list[str]
publication_date: str
journal: str
doi: Optional[str] = None
references: list[Reference]
generator = outlines.generate.json(model, DocumentMetadata)
paper = """
Title: Advances in Neural Machine Translation
Authors:
- Dr. Jane Smith (jane@university.edu), MIT
- Prof. John Doe (jdoe@stanford.edu), Stanford University
Abstract: This paper presents novel approaches to neural machine translation
using transformer architectures. We demonstrate significant improvements in
translation quality across multiple language pairs.
Keywords: Neural Networks, Machine Translation, Transformers, NLP
Published: Journal of AI Research, 2024-03-15
DOI: 10.1234/jair.2024.001
References:
1. "Attention Is All You Need" by Vaswani et al., 2017, NeurIPS
2. "BERT: Pre-training of Deep Bidirectional Transformers" by Devlin et al., 2019, NAACL
"""
metadata = generator(f"Extract document metadata:\n{paper}\n\nMetadata:")
print(f"Title: {metadata.title}")
print(f"Authors: {', '.join(a.name for a in metadata.authors)}")
print(f"Keywords: {', '.join(metadata.keywords)}")
print(f"References: {len(metadata.references)}")
```
## Code Generation
### Python Function Generation
```python
class Parameter(BaseModel):
name: str = Field(pattern=r"^[a-z_][a-z0-9_]*$")
type_hint: str
default: Optional[str] = None
class PythonFunction(BaseModel):
function_name: str = Field(pattern=r"^[a-z_][a-z0-9_]*$")
parameters: list[Parameter]
return_type: str
docstring: str
body: list[str] # Lines of code
model = outlines.models.transformers("microsoft/Phi-3-mini-4k-instruct")
generator = outlines.generate.json(model, PythonFunction)
spec = "Create a function to calculate the factorial of a number"
func = generator(f"Generate Python function:\n{spec}\n\nFunction:")
print(f"def {func.function_name}(", end="")
print(", ".join(f"{p.name}: {p.type_hint}" for p in func.parameters), end="")
print(f") -> {func.return_type}:")
print(f' """{func.docstring}"""')
for line in func.body:
print(f" {line}")
```
### SQL Query Generation
```python
class SQLQuery(BaseModel):
query_type: Literal["SELECT", "INSERT", "UPDATE", "DELETE"]
select_columns: Optional[list[str]] = None
from_tables: list[str]
joins: Optional[list[str]] = None
where_conditions: Optional[list[str]] = None
group_by: Optional[list[str]] = None
order_by: Optional[list[str]] = None
limit: Optional[int] = None
generator = outlines.generate.json(model, SQLQuery)
request = "Get top 10 users who made purchases in the last 30 days, ordered by total spent"
sql = generator(f"Generate SQL query:\n{request}\n\nQuery:")
print(f"Query type: {sql.query_type}")
print(f"SELECT {', '.join(sql.select_columns)}")
print(f"FROM {', '.join(sql.from_tables)}")
if sql.joins:
for join in sql.joins:
print(f" {join}")
if sql.where_conditions:
print(f"WHERE {' AND '.join(sql.where_conditions)}")
if sql.order_by:
print(f"ORDER BY {', '.join(sql.order_by)}")
if sql.limit:
print(f"LIMIT {sql.limit}")
```
### API Endpoint Spec
```python
class Parameter(BaseModel):
name: str
type: str
required: bool
description: str
class APIEndpoint(BaseModel):
method: Literal["GET", "POST", "PUT", "DELETE", "PATCH"]
path: str
description: str
parameters: list[Parameter]
request_body: Optional[dict] = None
response_schema: dict
status_codes: dict[int, str]
generator = outlines.generate.json(model, APIEndpoint)
spec = "Create user endpoint"
endpoint = generator(f"Generate API endpoint:\n{spec}\n\nEndpoint:")
print(f"{endpoint.method} {endpoint.path}")
print(f"Description: {endpoint.description}")
print("\nParameters:")
for param in endpoint.parameters:
req = "required" if param.required else "optional"
print(f" - {param.name} ({param.type}, {req}): {param.description}")
```
## Batch Processing
### Batch Extraction
```python
def batch_extract(texts: list[str], schema: type[BaseModel], model_name: str):
"""Extract structured data from multiple texts."""
model = outlines.models.transformers(model_name)
generator = outlines.generate.json(model, schema)
results = []
for i, text in enumerate(texts):
print(f"Processing {i+1}/{len(texts)}...", end="\r")
result = generator(f"Extract:\n{text}\n\nData:")
results.append(result)
return results
class Product(BaseModel):
name: str
price: float
category: str
texts = [
"iPhone 15 Pro costs $999 in Electronics",
"Running Shoes are $89.99 in Sports",
"Coffee Maker priced at $49.99 in Home & Kitchen"
]
products = batch_extract(texts, Product, "microsoft/Phi-3-mini-4k-instruct")
for product in products:
print(f"{product.name}: ${product.price} ({product.category})")
```
### CSV Processing
```python
import csv
def process_csv(csv_file: str, schema: type[BaseModel]):
"""Process CSV file and extract structured data."""
model = outlines.models.transformers("microsoft/Phi-3-mini-4k-instruct")
generator = outlines.generate.json(model, schema)
results = []
with open(csv_file, 'r') as f:
reader = csv.DictReader(f)
for row in reader:
text = " | ".join(f"{k}: {v}" for k, v in row.items())
result = generator(f"Extract:\n{text}\n\nData:")
results.append(result)
return results
class Customer(BaseModel):
name: str
email: str
tier: Literal["basic", "premium", "enterprise"]
mrr: float
# customers = process_csv("customers.csv", Customer)
```
## Production Patterns
### Error Handling
```python
from pydantic import ValidationError
def safe_extract(text: str, schema: type[BaseModel], retries: int = 3):
"""Extract with error handling and retries."""
model = outlines.models.transformers("microsoft/Phi-3-mini-4k-instruct")
generator = outlines.generate.json(model, schema)
for attempt in range(retries):
try:
result = generator(f"Extract:\n{text}\n\nData:")
return result
except ValidationError as e:
print(f"Attempt {attempt + 1} failed: {e}")
if attempt == retries - 1:
raise
except Exception as e:
print(f"Unexpected error: {e}")
if attempt == retries - 1:
raise
return None
```
### Caching
```python
import hashlib

# Simple in-process cache keyed on the text hash and schema name.
# (lru_cache alone cannot do this, because the cached helper would need
# the full text and schema to perform the extraction on a miss.)
_extraction_cache: dict[tuple[str, str], BaseModel] = {}

def extract_with_cache(text: str, schema: type[BaseModel]):
    """Extract with caching; repeated texts are served from the cache."""
    key = (hashlib.md5(text.encode()).hexdigest(), schema.__name__)
    if key in _extraction_cache:
        return _extraction_cache[key]
    # Perform actual extraction
    model = outlines.models.transformers("microsoft/Phi-3-mini-4k-instruct")
    generator = outlines.generate.json(model, schema)
    result = generator(f"Extract:\n{text}\n\nData:")
    _extraction_cache[key] = result
    return result
```
### Monitoring
```python
import time
import logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def monitored_extract(text: str, schema: type[BaseModel]):
"""Extract with monitoring and logging."""
start_time = time.time()
try:
model = outlines.models.transformers("microsoft/Phi-3-mini-4k-instruct")
generator = outlines.generate.json(model, schema)
result = generator(f"Extract:\n{text}\n\nData:")
elapsed = time.time() - start_time
logger.info(f"Extraction succeeded in {elapsed:.2f}s")
logger.info(f"Input length: {len(text)} chars")
return result
except Exception as e:
elapsed = time.time() - start_time
logger.error(f"Extraction failed after {elapsed:.2f}s: {e}")
raise
```
### Rate Limiting
```python
import time
from threading import Lock
class RateLimiter:
def __init__(self, max_requests: int, time_window: int):
self.max_requests = max_requests
self.time_window = time_window
self.requests = []
self.lock = Lock()
def wait_if_needed(self):
with self.lock:
now = time.time()
# Remove old requests
self.requests = [r for r in self.requests if now - r < self.time_window]
if len(self.requests) >= self.max_requests:
sleep_time = self.time_window - (now - self.requests[0])
time.sleep(sleep_time)
self.requests = []
self.requests.append(now)
def rate_limited_extract(texts: list[str], schema: type[BaseModel]):
"""Extract with rate limiting."""
limiter = RateLimiter(max_requests=10, time_window=60) # 10 req/min
model = outlines.models.transformers("microsoft/Phi-3-mini-4k-instruct")
generator = outlines.generate.json(model, schema)
results = []
for text in texts:
limiter.wait_if_needed()
result = generator(f"Extract:\n{text}\n\nData:")
results.append(result)
return results
```
## Resources
- **Outlines Documentation**: https://outlines-dev.github.io/outlines
- **Pydantic Documentation**: https://docs.pydantic.dev
- **GitHub Examples**: https://github.com/outlines-dev/outlines/tree/main/examples

View file

@@ -0,0 +1,652 @@
# Comprehensive JSON Generation Guide
Complete guide to JSON generation with Outlines using Pydantic models and JSON schemas.
## Table of Contents
- Pydantic Models
- JSON Schema Support
- Advanced Patterns
- Nested Structures
- Complex Types
- Validation
- Performance Optimization
## Pydantic Models
### Basic Models
```python
from pydantic import BaseModel
import outlines
class User(BaseModel):
name: str
age: int
email: str
model = outlines.models.transformers("microsoft/Phi-3-mini-4k-instruct")
generator = outlines.generate.json(model, User)
user = generator("Generate user: Alice, 25, alice@example.com")
print(user.name) # "Alice"
print(user.age) # 25
print(user.email) # "alice@example.com"
```
### Field Constraints
```python
from pydantic import BaseModel, Field
class Product(BaseModel):
name: str = Field(min_length=1, max_length=100)
price: float = Field(gt=0, description="Price in USD")
discount: float = Field(ge=0, le=100, description="Discount percentage")
quantity: int = Field(ge=0, description="Available quantity")
sku: str = Field(pattern=r"^[A-Z]{3}-\d{6}$")
model = outlines.models.transformers("microsoft/Phi-3-mini-4k-instruct")
generator = outlines.generate.json(model, Product)
product = generator("Generate product: iPhone 15, $999")
# All fields guaranteed to meet constraints
```
**Available Constraints:**
- `min_length`, `max_length`: String length
- `gt`, `ge`, `lt`, `le`: Numeric comparisons
- `multiple_of`: Number must be multiple of value
- `pattern`: Regex pattern for strings
- `min_items`, `max_items`: List length (Pydantic v2 uses `min_length`/`max_length` for lists as well)
### Optional Fields
```python
from typing import Optional
class Article(BaseModel):
title: str # Required
author: Optional[str] = None # Optional
published_date: Optional[str] = None # Optional
tags: list[str] = [] # Default empty list
view_count: int = 0 # Default value
generator = outlines.generate.json(model, Article)
# Can generate even if optional fields missing
article = generator("Title: Introduction to AI")
print(article.author) # None (not provided)
print(article.tags) # [] (default)
```
### Default Values
```python
class Config(BaseModel):
debug: bool = False
max_retries: int = 3
timeout: float = 30.0
log_level: str = "INFO"
# Generator uses defaults when not specified
generator = outlines.generate.json(model, Config)
config = generator("Generate config with debug enabled")
print(config.debug) # True (from prompt)
print(config.timeout) # 30.0 (default)
```
## Enums and Literals
### Enum Fields
```python
from enum import Enum
class Status(str, Enum):
PENDING = "pending"
APPROVED = "approved"
REJECTED = "rejected"
CANCELLED = "cancelled"
class Application(BaseModel):
applicant_name: str
status: Status # Must be one of enum values
submitted_date: str
generator = outlines.generate.json(model, Application)
app = generator("Generate application for John Doe")
print(app.status) # Status.PENDING (or one of the enum values)
print(type(app.status)) # <enum 'Status'>
```
### Literal Types
```python
from typing import Literal
class Task(BaseModel):
title: str
priority: Literal["low", "medium", "high", "critical"]
status: Literal["todo", "in_progress", "done"]
assigned_to: str
generator = outlines.generate.json(model, Task)
task = generator("Create high priority task: Fix bug")
print(task.priority) # One of: "low", "medium", "high", "critical"
```
### Multiple Choice Fields
```python
class Survey(BaseModel):
question: str
answer: Literal["strongly_disagree", "disagree", "neutral", "agree", "strongly_agree"]
confidence: Literal["low", "medium", "high"]
generator = outlines.generate.json(model, Survey)
survey = generator("Rate: 'I enjoy using this product'")
```
## Nested Structures
### Nested Models
```python
class Address(BaseModel):
street: str
city: str
state: str
zip_code: str
country: str = "USA"
class Person(BaseModel):
name: str
age: int
email: str
address: Address # Nested model
model = outlines.models.transformers("microsoft/Phi-3-mini-4k-instruct")
generator = outlines.generate.json(model, Person)
prompt = """
Extract person:
Name: Alice Johnson
Age: 28
Email: alice@example.com
Address: 123 Main St, Boston, MA, 02101
"""
person = generator(prompt)
print(person.name) # "Alice Johnson"
print(person.address.city) # "Boston"
print(person.address.state) # "MA"
```
### Deep Nesting
```python
class Coordinates(BaseModel):
latitude: float
longitude: float
class Location(BaseModel):
name: str
coordinates: Coordinates
class Event(BaseModel):
title: str
date: str
location: Location
generator = outlines.generate.json(model, Event)
event = generator("Generate event: Tech Conference in San Francisco")
print(event.title) # "Tech Conference"
print(event.location.name) # "San Francisco"
print(event.location.coordinates.latitude) # 37.7749
```
### Lists of Nested Models
```python
class Item(BaseModel):
name: str
quantity: int
price: float
class Order(BaseModel):
order_id: str
customer: str
items: list[Item] # List of nested models
total: float
generator = outlines.generate.json(model, Order)
prompt = """
Generate order for John:
- 2x Widget ($10 each)
- 3x Gadget ($15 each)
Order ID: ORD-001
"""
order = generator(prompt)
print(f"Order ID: {order.order_id}")
for item in order.items:
print(f"- {item.quantity}x {item.name} @ ${item.price}")
print(f"Total: ${order.total}")
```
## Complex Types
### Union Types
```python
from typing import Union
class TextContent(BaseModel):
type: Literal["text"]
content: str
class ImageContent(BaseModel):
type: Literal["image"]
url: str
caption: str
class Post(BaseModel):
title: str
content: Union[TextContent, ImageContent] # Either type
generator = outlines.generate.json(model, Post)
# Can generate either text or image content
post = generator("Generate blog post with image")
if post.content.type == "text":
print(post.content.content)
elif post.content.type == "image":
print(post.content.url)
```
### Lists and Arrays
```python
class Article(BaseModel):
title: str
authors: list[str] # List of strings
tags: list[str]
sections: list[dict[str, str]] # List of dicts
related_ids: list[int]
generator = outlines.generate.json(model, Article)
article = generator("Generate article about AI")
print(article.authors) # ["Alice", "Bob"]
print(article.tags) # ["AI", "Machine Learning", "Technology"]
```
### Dictionaries
```python
class Metadata(BaseModel):
title: str
properties: dict[str, str] # String keys and values
counts: dict[str, int] # String keys, int values
settings: dict[str, Union[str, int, bool]] # Mixed value types
generator = outlines.generate.json(model, Metadata)
meta = generator("Generate metadata")
print(meta.properties) # {"author": "Alice", "version": "1.0"}
print(meta.counts) # {"views": 1000, "likes": 50}
```
### Any Type (Use Sparingly)
```python
from typing import Any
class FlexibleData(BaseModel):
name: str
structured_field: str
flexible_field: Any # Can be anything
# Note: Any reduces type safety, use only when necessary
generator = outlines.generate.json(model, FlexibleData)
```
## JSON Schema Support
### Direct Schema Usage
```python
import outlines
model = outlines.models.transformers("microsoft/Phi-3-mini-4k-instruct")
# Define JSON schema
schema = {
"type": "object",
"properties": {
"name": {"type": "string"},
"age": {"type": "integer", "minimum": 0, "maximum": 120},
"email": {"type": "string", "format": "email"}
},
"required": ["name", "age", "email"]
}
# Generate from schema
generator = outlines.generate.json(model, schema)
result = generator("Generate person: Alice, 25, alice@example.com")
print(result) # Valid JSON matching schema
```
### Schema from Pydantic
```python
class User(BaseModel):
name: str
age: int
email: str
# Get JSON schema from Pydantic model
schema = User.model_json_schema()
print(schema)
# {
# "type": "object",
# "properties": {
# "name": {"type": "string"},
# "age": {"type": "integer"},
# "email": {"type": "string"}
# },
# "required": ["name", "age", "email"]
# }
# Both approaches equivalent:
generator1 = outlines.generate.json(model, User)
generator2 = outlines.generate.json(model, schema)
```
## Advanced Patterns
### Conditional Fields
```python
class Order(BaseModel):
order_type: Literal["standard", "express"]
delivery_date: str
express_fee: Optional[float] = None # Only for express orders
generator = outlines.generate.json(model, Order)
# Express order
order1 = generator("Create express order for tomorrow")
print(order1.express_fee) # 25.0
# Standard order
order2 = generator("Create standard order")
print(order2.express_fee) # None
```
### Recursive Models
```python
from typing import Optional, List
class TreeNode(BaseModel):
value: str
children: Optional[List['TreeNode']] = None
# Enable forward references
TreeNode.model_rebuild()
generator = outlines.generate.json(model, TreeNode)
tree = generator("Generate file tree with subdirectories")
print(tree.value) # "root"
print(tree.children[0].value) # "subdir1"
```
### Model with Validation
```python
from pydantic import field_validator
class DateRange(BaseModel):
start_date: str
end_date: str
@field_validator('end_date')
def end_after_start(cls, v, info):
"""Ensure end_date is after start_date."""
if 'start_date' in info.data:
from datetime import datetime
start = datetime.strptime(info.data['start_date'], '%Y-%m-%d')
end = datetime.strptime(v, '%Y-%m-%d')
if end < start:
raise ValueError('end_date must be after start_date')
return v
generator = outlines.generate.json(model, DateRange)
# Validation happens after generation
```
## Multiple Objects
### Generate List of Objects
```python
class Person(BaseModel):
name: str
age: int
class Team(BaseModel):
team_name: str
members: list[Person]
generator = outlines.generate.json(model, Team)
team = generator("Generate engineering team with 5 members")
print(f"Team: {team.team_name}")
for member in team.members:
print(f"- {member.name}, {member.age}")
```
### Batch Generation
```python
def generate_batch(prompts: list[str], schema: type[BaseModel]):
"""Generate structured outputs for multiple prompts."""
model = outlines.models.transformers("microsoft/Phi-3-mini-4k-instruct")
generator = outlines.generate.json(model, schema)
results = []
for prompt in prompts:
result = generator(prompt)
results.append(result)
return results
class Product(BaseModel):
name: str
price: float
prompts = [
"Product: iPhone 15, $999",
"Product: MacBook Pro, $2499",
"Product: AirPods, $179"
]
products = generate_batch(prompts, Product)
for product in products:
print(f"{product.name}: ${product.price}")
```
## Performance Optimization
### Caching Generators
```python
from functools import lru_cache
@lru_cache(maxsize=10)
def get_generator(model_name: str, schema: type[BaseModel]):
    """Cache generators for reuse (classes are hashable, so the schema works as a cache key)."""
    model = outlines.models.transformers(model_name)
    return outlines.generate.json(model, schema)
# First call: creates generator
gen1 = get_generator("microsoft/Phi-3-mini-4k-instruct", User)
# Second call: returns cached generator (fast!)
gen2 = get_generator("microsoft/Phi-3-mini-4k-instruct", User)
```
### Batch Processing
```python
# Process multiple items efficiently
model = outlines.models.transformers("microsoft/Phi-3-mini-4k-instruct")
generator = outlines.generate.json(model, User)
texts = ["User: Alice, 25", "User: Bob, 30", "User: Carol, 35"]
# Reuse generator (model stays loaded)
users = [generator(text) for text in texts]
```
### Minimize Schema Complexity
```python
# ✅ Good: Simple, flat structure (faster)
class SimplePerson(BaseModel):
name: str
age: int
city: str
# ⚠️ Slower: Deep nesting
class ComplexPerson(BaseModel):
personal_info: PersonalInfo
address: Address
employment: Employment
# ... many nested levels
```
## Error Handling
### Handle Missing Fields
```python
from pydantic import ValidationError
class User(BaseModel):
name: str
age: int
email: str
try:
user = generator("Generate user") # May not include all fields
except ValidationError as e:
print(f"Validation error: {e}")
# Handle gracefully
```
### Fallback with Optional Fields
```python
class RobustUser(BaseModel):
name: str # Required
age: Optional[int] = None # Optional
email: Optional[str] = None # Optional
# More likely to succeed even with incomplete data
user = generator("Generate user: Alice")
print(user.name) # "Alice"
print(user.age) # None (not provided)
```
## Best Practices
### 1. Use Specific Types
```python
# ✅ Good: Specific types
class Product(BaseModel):
name: str
price: float # Not Any or str
quantity: int # Not str
in_stock: bool # Not int
# ❌ Bad: Generic types
class Product(BaseModel):
name: Any
price: str # Should be float
quantity: str # Should be int
```
### 2. Add Descriptions
```python
# ✅ Good: Clear descriptions
class Article(BaseModel):
title: str = Field(description="Article title, 10-100 characters")
content: str = Field(description="Main article content in paragraphs")
tags: list[str] = Field(description="List of relevant topic tags")
# Descriptions help the model understand expected output
```
### 3. Use Constraints
```python
# ✅ Good: With constraints
class Age(BaseModel):
value: int = Field(ge=0, le=120, description="Age in years")
# ❌ Bad: No constraints
class Age(BaseModel):
value: int # Could be negative or > 120
```
### 4. Prefer Enums Over Strings
```python
# ✅ Good: Enum for fixed set
class Priority(str, Enum):
LOW = "low"
MEDIUM = "medium"
HIGH = "high"
class Task(BaseModel):
priority: Priority # Guaranteed valid
# ❌ Bad: Free-form string
class Task(BaseModel):
priority: str # Could be "urgent", "ASAP", "!!", etc.
```
### 5. Test Your Models
```python
# Test models work as expected
def test_product_model():
product = Product(
name="Test Product",
price=19.99,
quantity=10,
in_stock=True
)
assert product.price == 19.99
assert isinstance(product, Product)
# Run tests before using in production
```
## Resources
- **Pydantic Docs**: https://docs.pydantic.dev
- **JSON Schema**: https://json-schema.org
- **Outlines GitHub**: https://github.com/outlines-dev/outlines

View file

@ -0,0 +1,190 @@
---
name: tensorrt-llm
description: Optimizes LLM inference with NVIDIA TensorRT for maximum throughput and lowest latency. Use for production deployment on NVIDIA GPUs (A100/H100), when you need 10-100x faster inference than PyTorch, or for serving models with quantization (FP8/INT4), in-flight batching, and multi-GPU scaling.
version: 1.0.0
author: Orchestra Research
license: MIT
dependencies: [tensorrt-llm, torch]
metadata:
hermes:
tags: [Inference Serving, TensorRT-LLM, NVIDIA, Inference Optimization, High Throughput, Low Latency, Production, FP8, INT4, In-Flight Batching, Multi-GPU]
---
# TensorRT-LLM
NVIDIA's open-source library for optimizing LLM inference with state-of-the-art performance on NVIDIA GPUs.
## When to use TensorRT-LLM
**Use TensorRT-LLM when:**
- Deploying on NVIDIA GPUs (A100, H100, GB200)
- Need maximum throughput (24,000+ tokens/sec on Llama 3)
- Require low latency for real-time applications
- Working with quantized models (FP8, INT4, FP4)
- Scaling across multiple GPUs or nodes
**Use vLLM instead when:**
- Need simpler setup and Python-first API
- Want PagedAttention without TensorRT compilation
- Working with AMD GPUs or non-NVIDIA hardware
**Use llama.cpp instead when:**
- Deploying on CPU or Apple Silicon
- Need edge deployment without NVIDIA GPUs
- Want simpler GGUF quantization format
## Quick start
### Installation
```bash
# Docker (recommended)
docker pull nvidia/tensorrt_llm:latest
# pip install
pip install tensorrt_llm==1.2.0rc3
# Requires CUDA 13.0.0, TensorRT 10.13.2, Python 3.10-3.12
```
### Basic inference
```python
from tensorrt_llm import LLM, SamplingParams
# Initialize model
llm = LLM(model="meta-llama/Meta-Llama-3-8B")
# Configure sampling
sampling_params = SamplingParams(
max_tokens=100,
temperature=0.7,
top_p=0.9
)
# Generate
prompts = ["Explain quantum computing"]
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
print(output.text)
```
### Serving with trtllm-serve
```bash
# Start server (automatic model download and compilation)
# --tp_size 4: tensor parallelism across 4 GPUs
trtllm-serve meta-llama/Meta-Llama-3-8B \
  --tp_size 4 \
  --max_batch_size 256 \
  --max_num_tokens 4096
# Client request
curl -X POST http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "meta-llama/Meta-Llama-3-8B",
"messages": [{"role": "user", "content": "Hello!"}],
"temperature": 0.7,
"max_tokens": 100
}'
```
## Key features
### Performance optimizations
- **In-flight batching**: Dynamic batching during generation
- **Paged KV cache**: Efficient memory management
- **Flash Attention**: Optimized attention kernels
- **Quantization**: FP8, INT4, FP4 for 2-4× faster inference
- **CUDA graphs**: Reduced kernel launch overhead
### Parallelism
- **Tensor parallelism (TP)**: Split model across GPUs
- **Pipeline parallelism (PP)**: Layer-wise distribution
- **Expert parallelism**: For Mixture-of-Experts models
- **Multi-node**: Scale beyond single machine
### Advanced features
- **Speculative decoding**: Faster generation with draft models
- **LoRA serving**: Efficient multi-adapter deployment
- **Disaggregated serving**: Separate prefill and generation
## Common patterns
### Quantized model (FP8)
```python
from tensorrt_llm import LLM
# Load FP8 quantized model (2× faster, 50% memory)
llm = LLM(
model="meta-llama/Meta-Llama-3-70B",
dtype="fp8",
max_num_tokens=8192
)
# Inference same as before
outputs = llm.generate(["Summarize this article..."])
```
### Multi-GPU deployment
```python
# Tensor parallelism across 8 GPUs
llm = LLM(
model="meta-llama/Meta-Llama-3-405B",
tensor_parallel_size=8,
dtype="fp8"
)
```
### Batch inference
```python
# Process 100 prompts efficiently
prompts = [f"Question {i}: ..." for i in range(100)]
outputs = llm.generate(
prompts,
sampling_params=SamplingParams(max_tokens=200)
)
# Automatic in-flight batching for maximum throughput
```
## Performance benchmarks
**Meta Llama 3-8B** (H100 GPU):
- Throughput: 24,000 tokens/sec
- Latency: ~10ms per token
- vs PyTorch: **100× faster**
**Llama 3-70B** (8× A100 80GB):
- FP8 quantization: 2× faster than FP16
- Memory: 50% reduction with FP8
## Supported models
- **LLaMA family**: Llama 2, Llama 3, CodeLlama
- **GPT family**: GPT-2, GPT-J, GPT-NeoX
- **Qwen**: Qwen, Qwen2, QwQ
- **DeepSeek**: DeepSeek-V2, DeepSeek-V3
- **Mixtral**: Mixtral-8x7B, Mixtral-8x22B
- **Vision**: LLaVA, Phi-3-vision
- **100+ models** on HuggingFace
## References
- **[Optimization Guide](references/optimization.md)** - Quantization, batching, KV cache tuning
- **[Multi-GPU Setup](references/multi-gpu.md)** - Tensor/pipeline parallelism, multi-node
- **[Serving Guide](references/serving.md)** - Production deployment, monitoring, autoscaling
## Resources
- **Docs**: https://nvidia.github.io/TensorRT-LLM/
- **GitHub**: https://github.com/NVIDIA/TensorRT-LLM
- **Models**: https://huggingface.co/models?library=tensorrt_llm

View file

@ -0,0 +1,298 @@
# Multi-GPU Deployment Guide
Comprehensive guide to scaling TensorRT-LLM across multiple GPUs and nodes.
## Parallelism Strategies
### Tensor Parallelism (TP)
**What it does**: Splits model layers across GPUs horizontally.
**Use case**:
- Model fits in total GPU memory but not single GPU
- Need low latency (single forward pass)
- GPUs on same node (NVLink required for best performance)
**Example** (Llama 3-70B on 4× A100):
```python
from tensorrt_llm import LLM
llm = LLM(
model="meta-llama/Meta-Llama-3-70B",
tensor_parallel_size=4, # Split across 4 GPUs
dtype="fp16"
)
# Model automatically sharded across GPUs
# Single forward pass, low latency
```
**Performance**:
- Latency: ~Same as single GPU
- Throughput: 4× higher (4 GPUs)
- Communication: High (activations synced every layer)
### Pipeline Parallelism (PP)
**What it does**: Splits model layers across GPUs vertically (layer-wise).
**Use case**:
- Very large models (175B+)
- Can tolerate higher latency
- GPUs across multiple nodes
**Example** (Llama 3-405B on 8× H100):
```python
llm = LLM(
model="meta-llama/Meta-Llama-3-405B",
tensor_parallel_size=4, # TP=4 within nodes
pipeline_parallel_size=2, # PP=2 across nodes
dtype="fp8"
)
# Total: 8 GPUs (4×2)
# Layers 0-40: Node 1 (4 GPUs with TP)
# Layers 41-80: Node 2 (4 GPUs with TP)
```
**Performance**:
- Latency: Higher (sequential through pipeline)
- Throughput: High with micro-batching
- Communication: Lower than TP
### Expert Parallelism (EP)
**What it does**: Distributes MoE experts across GPUs.
**Use case**: Mixture-of-Experts models (Mixtral, DeepSeek-V2)
**Example** (Mixtral-8x22B on 8× A100):
```python
llm = LLM(
model="mistralai/Mixtral-8x22B",
tensor_parallel_size=4,
expert_parallel_size=2, # Distribute 8 experts across 2 groups
dtype="fp8"
)
```
## Configuration Examples
### Small model (7-13B) - Single GPU
```python
# Llama 3-8B on 1× A100 80GB
llm = LLM(
model="meta-llama/Meta-Llama-3-8B",
dtype="fp16" # or fp8 for H100
)
```
**Resources**:
- GPU: 1× A100 80GB
- Memory: ~16GB model + 30GB KV cache
- Throughput: 3,000-5,000 tokens/sec
### Medium model (70B) - Multi-GPU same node
```python
# Llama 3-70B on 4× A100 80GB (NVLink)
llm = LLM(
model="meta-llama/Meta-Llama-3-70B",
tensor_parallel_size=4,
dtype="fp8" # 70GB → 35GB per GPU
)
```
**Resources**:
- GPU: 4× A100 80GB with NVLink
- Memory: ~35GB per GPU (FP8)
- Throughput: 10,000-15,000 tokens/sec
- Latency: 15-20ms per token
### Large model (405B) - Multi-node
```python
# Llama 3-405B on 2 nodes × 8 H100 = 16 GPUs
llm = LLM(
model="meta-llama/Meta-Llama-3-405B",
tensor_parallel_size=8, # TP within each node
pipeline_parallel_size=2, # PP across 2 nodes
dtype="fp8"
)
```
**Resources**:
- GPU: 2 nodes × 8 H100 80GB
- Memory: ~25GB per GPU (FP8)
- Throughput: 20,000-30,000 tokens/sec
- Network: InfiniBand recommended
## Server Deployment
### Single-node multi-GPU
```bash
# Llama 3-70B on 4 GPUs (automatic TP)
trtllm-serve meta-llama/Meta-Llama-3-70B \
--tp_size 4 \
--max_batch_size 256 \
--dtype fp8
# Listens on http://localhost:8000
```
### Multi-node with Ray
```bash
# Node 1 (head node)
ray start --head --port=6379
# Node 2 (worker)
ray start --address='node1:6379'
# Deploy across cluster
trtllm-serve meta-llama/Meta-Llama-3-405B \
--tp_size 8 \
--pp_size 2 \
  --num_workers 2 \
  --dtype fp8   # num_workers 2 = one worker per node
```
### Kubernetes deployment
```yaml
apiVersion: apps/v1
kind: Deployment
metadata:
name: tensorrt-llm-llama3-70b
spec:
replicas: 1
template:
spec:
containers:
- name: trtllm
image: nvidia/tensorrt_llm:latest
command:
- trtllm-serve
- meta-llama/Meta-Llama-3-70B
- --tp_size=4
- --max_batch_size=256
resources:
limits:
nvidia.com/gpu: 4 # Request 4 GPUs
```
## Parallelism Decision Tree
```
Model size < 20GB?
├─ YES: Single GPU (no parallelism)
└─ NO: Model size < 80GB?
├─ YES: TP=2 or TP=4 (same node)
└─ NO: Model size < 320GB?
├─ YES: TP=4 or TP=8 (same node, NVLink required)
└─ NO: TP=8 + PP=2 (multi-node)
```
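The same logic can be written as a small helper. This is an illustrative sketch of the decision tree above, not a TensorRT-LLM API; the thresholds and the `ParallelismPlan` name are assumptions for demonstration.
```python
from dataclasses import dataclass

@dataclass
class ParallelismPlan:
    tensor_parallel_size: int
    pipeline_parallel_size: int

def choose_parallelism(model_size_gb: float) -> ParallelismPlan:
    """Pick TP/PP following the decision tree above (illustrative thresholds)."""
    if model_size_gb < 20:
        return ParallelismPlan(1, 1)   # Single GPU, no parallelism
    if model_size_gb < 80:
        return ParallelismPlan(4, 1)   # TP=2-4 on one node
    if model_size_gb < 320:
        return ParallelismPlan(8, 1)   # TP=4-8, same node, NVLink required
    return ParallelismPlan(8, 2)       # Multi-node: TP=8 + PP=2

# Example: Llama 3-70B in FP8 is roughly 70 GB of weights
print(choose_parallelism(70))  # ParallelismPlan(tensor_parallel_size=4, pipeline_parallel_size=1)
```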
## Communication Optimization
### NVLink vs PCIe
**NVLink** (DGX A100, HGX H100):
- Bandwidth: 600 GB/s (A100), 900 GB/s (H100)
- Ideal for TP (high communication)
- **Recommended for all multi-GPU setups**
**PCIe**:
- Bandwidth: 64 GB/s (PCIe 4.0 x16)
- 10× slower than NVLink
- Avoid TP, use PP instead
### InfiniBand for multi-node
**HDR InfiniBand** (200 Gb/s):
- Required for multi-node TP or PP
- Latency: <1μs
- **Essential for 405B+ models**
## Monitoring Multi-GPU
```bash
# Monitor GPU utilization
nvidia-smi dmon -s u
# Monitor memory
nvidia-smi dmon -s m
# Monitor NVLink utilization
nvidia-smi nvlink --status
# TensorRT-LLM built-in metrics
curl http://localhost:8000/metrics
```
**Key metrics**:
- GPU utilization: Target 80-95%
- Memory usage: Should be balanced across GPUs
- NVLink traffic: High for TP, low for PP
- Throughput: Tokens/sec across all GPUs
## Common Issues
### Imbalanced GPU memory
**Symptom**: GPU 0 has 90% memory, GPU 3 has 40%
**Solutions**:
- Verify TP/PP configuration
- Check model sharding (should be equal)
- Restart server to reset state
### Low NVLink utilization
**Symptom**: NVLink bandwidth <100 GB/s with TP=4
**Solutions**:
- Verify NVLink topology: `nvidia-smi topo -m`
- Check for PCIe fallback
- Ensure GPUs are on same NVSwitch
### OOM with multi-GPU
**Solutions**:
- Increase TP size (more GPUs)
- Reduce batch size
- Enable FP8 quantization
- Use pipeline parallelism
## Performance Scaling
### TP Scaling (Llama 3-70B, FP8)
| GPUs | TP Size | Throughput | Latency | Efficiency |
|------|---------|------------|---------|------------|
| 1 | 1 | OOM | - | - |
| 2 | 2 | 6,000 tok/s | 18ms | 85% |
| 4 | 4 | 11,000 tok/s | 16ms | 78% |
| 8 | 8 | 18,000 tok/s | 15ms | 64% |
**Note**: Efficiency drops with more GPUs due to communication overhead.
### PP Scaling (Llama 3-405B, FP8)
| Nodes | TP | PP | Total GPUs | Throughput |
|-------|----|----|------------|------------|
| 1 | 8 | 1 | 8 | OOM |
| 2 | 8 | 2 | 16 | 25,000 tok/s |
| 4 | 8 | 4 | 32 | 45,000 tok/s |
## Best Practices
1. **Prefer TP over PP** when possible (lower latency)
2. **Use NVLink** for all TP deployments
3. **Use InfiniBand** for multi-node deployments
4. **Start with smallest TP** that fits model in memory
5. **Monitor GPU balance** - all GPUs should have similar utilization
6. **Test with benchmark** before production
7. **Use FP8** on H100 for 2× speedup

View file

@ -0,0 +1,242 @@
# TensorRT-LLM Optimization Guide
Comprehensive guide to optimizing LLM inference with TensorRT-LLM.
## Quantization
### FP8 Quantization (Recommended for H100)
**Benefits**:
- 2× faster inference
- 50% memory reduction
- Minimal accuracy loss (<1% perplexity degradation)
**Usage**:
```python
from tensorrt_llm import LLM
# Automatic FP8 quantization
llm = LLM(
model="meta-llama/Meta-Llama-3-70B",
dtype="fp8",
quantization="fp8"
)
```
**Performance** (Llama 3-70B on 8× H100):
- FP16: 5,000 tokens/sec
- FP8: **10,000 tokens/sec** (2× speedup)
- Memory: 140GB → 70GB
### INT4 Quantization (Maximum compression)
**Benefits**:
- 4× memory reduction
- 3-4× faster inference
- Fits larger models on same hardware
**Usage**:
```python
# INT4 with AWQ calibration
llm = LLM(
model="meta-llama/Meta-Llama-3-405B",
dtype="int4_awq",
quantization="awq"
)
# INT4 with GPTQ calibration
llm = LLM(
model="meta-llama/Meta-Llama-3-405B",
dtype="int4_gptq",
quantization="gptq"
)
```
**Trade-offs**:
- Accuracy: 1-3% perplexity increase
- Speed: 3-4× faster than FP16
- Use case: When memory is critical
## In-Flight Batching
**What it does**: Dynamically batches requests during generation instead of waiting for all sequences to finish.
**Configuration**:
```bash
# Server configuration:
#   --max_batch_size 256      maximum concurrent sequences
#   --max_num_tokens 4096     total tokens in a batch
#   --enable_chunked_context  split long prompts
trtllm-serve meta-llama/Meta-Llama-3-8B \
  --max_batch_size 256 \
  --max_num_tokens 4096 \
  --enable_chunked_context \
  --scheduler_policy max_utilization
```
**Performance**:
- Throughput: **4-8× higher** vs static batching
- Latency: Lower P50/P99 for mixed workloads
- GPU utilization: 80-95% vs 40-60%
## Paged KV Cache
**What it does**: Manages KV cache memory like OS manages virtual memory (paging).
**Benefits**:
- 40-60% higher throughput
- No memory fragmentation
- Supports longer sequences
**Configuration**:
```python
# Automatic paged KV cache (default)
llm = LLM(
model="meta-llama/Meta-Llama-3-8B",
kv_cache_free_gpu_mem_fraction=0.9, # Use 90% GPU mem for cache
enable_prefix_caching=True # Cache common prefixes
)
```
## Speculative Decoding
**What it does**: Uses small draft model to predict multiple tokens, verified by target model in parallel.
**Speedup**: 2-3× faster for long generations
**Usage**:
```python
from tensorrt_llm import LLM
# Target model (Llama 3-70B)
llm = LLM(
model="meta-llama/Meta-Llama-3-70B",
speculative_model="meta-llama/Meta-Llama-3-8B", # Draft model
num_speculative_tokens=5 # Tokens to predict ahead
)
# Same API, 2-3× faster
outputs = llm.generate(prompts)
```
**Best models for drafting**:
- Target: Llama 3-70B → Draft: Llama 3-8B
- Target: Qwen2-72B → Draft: Qwen2-7B
- Same family, 8-10× smaller
## CUDA Graphs
**What it does**: Reduces kernel launch overhead by recording GPU operations.
**Benefits**:
- 10-20% lower latency
- More stable P99 latency
- Better for small batch sizes
**Configuration** (automatic by default):
```python
llm = LLM(
model="meta-llama/Meta-Llama-3-8B",
enable_cuda_graph=True, # Default: True
cuda_graph_cache_size=2 # Cache 2 graph variants
)
```
## Chunked Context
**What it does**: Splits long prompts into chunks to reduce memory spikes.
**Use case**: Prompts >8K tokens with limited GPU memory
**Configuration**:
```bash
trtllm-serve meta-llama/Meta-Llama-3-8B \
--max_num_tokens 4096 \
--enable_chunked_context \
--max_chunked_prefill_length 2048 # Process 2K tokens at a time
```
## Overlap Scheduling
**What it does**: Overlaps compute and memory operations.
**Benefits**:
- 15-25% higher throughput
- Better GPU utilization
- Default in v1.2.0+
**No configuration needed** - enabled automatically.
## Quantization Comparison Table
| Method | Memory | Speed | Accuracy | Use Case |
|--------|--------|-------|----------|----------|
| FP16 | 1× (baseline) | 1× | Best | High accuracy needed |
| FP8 | 0.5× | 2× | -0.5% ppl | **H100 default** |
| INT4 AWQ | 0.25× | 3-4× | -1.5% ppl | Memory critical |
| INT4 GPTQ | 0.25× | 3-4× | -2% ppl | Maximum speed |
## Tuning Workflow
1. **Start with defaults**:
```python
llm = LLM(model="meta-llama/Meta-Llama-3-70B")
```
2. **Enable FP8** (if H100):
```python
llm = LLM(model="...", dtype="fp8")
```
3. **Tune batch size**:
   ```bash
# Increase until OOM, then reduce 20%
trtllm-serve ... --max_batch_size 256
```
4. **Enable chunked context** (if long prompts):
```bash
--enable_chunked_context --max_chunked_prefill_length 2048
```
5. **Try speculative decoding** (if latency critical):
```python
llm = LLM(model="...", speculative_model="...")
```
## Benchmarking
```bash
# Install benchmark tool
pip install tensorrt_llm[benchmark]
# Run benchmark
python benchmarks/python/benchmark.py \
--model meta-llama/Meta-Llama-3-8B \
--batch_size 64 \
--input_len 128 \
--output_len 256 \
--dtype fp8
```
**Metrics to track**:
- Throughput (tokens/sec)
- Latency P50/P90/P99 (ms)
- GPU memory usage (GB)
- GPU utilization (%)
## Common Issues
**OOM errors**:
- Reduce `max_batch_size`
- Reduce `max_num_tokens`
- Enable INT4 quantization
- Increase `tensor_parallel_size`
**Low throughput**:
- Increase `max_batch_size`
- Enable in-flight batching
- Verify CUDA graphs enabled
- Check GPU utilization
**High latency**:
- Try speculative decoding
- Reduce `max_batch_size` (less queueing)
- Use FP8 instead of FP16

View file

@ -0,0 +1,470 @@
# Production Serving Guide
Comprehensive guide to deploying TensorRT-LLM in production environments.
## Server Modes
### trtllm-serve (Recommended)
**Features**:
- OpenAI-compatible API
- Automatic model download and compilation
- Built-in load balancing
- Prometheus metrics
- Health checks
**Basic usage**:
```bash
trtllm-serve meta-llama/Meta-Llama-3-8B \
--tp_size 1 \
--max_batch_size 256 \
--port 8000
```
**Advanced configuration**:
```bash
trtllm-serve meta-llama/Meta-Llama-3-70B \
--tp_size 4 \
--dtype fp8 \
--max_batch_size 256 \
--max_num_tokens 4096 \
--enable_chunked_context \
--scheduler_policy max_utilization \
--port 8000 \
--api_key $API_KEY # Optional authentication
```
### Python LLM API (For embedding)
```python
from tensorrt_llm import LLM
class LLMService:
def __init__(self):
self.llm = LLM(
model="meta-llama/Meta-Llama-3-8B",
dtype="fp8"
)
def generate(self, prompt, max_tokens=100):
from tensorrt_llm import SamplingParams
params = SamplingParams(
max_tokens=max_tokens,
temperature=0.7
)
outputs = self.llm.generate([prompt], params)
return outputs[0].text
# Use in FastAPI, Flask, etc.
from fastapi import FastAPI
app = FastAPI()
service = LLMService()
@app.post("/generate")
def generate(prompt: str):
return {"response": service.generate(prompt)}
```
## OpenAI-Compatible API
### Chat Completions
```bash
curl -X POST http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "meta-llama/Meta-Llama-3-8B",
"messages": [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Explain quantum computing"}
],
"temperature": 0.7,
"max_tokens": 500,
"stream": false
}'
```
**Response**:
```json
{
"id": "chat-abc123",
"object": "chat.completion",
"created": 1234567890,
"model": "meta-llama/Meta-Llama-3-8B",
"choices": [{
"index": 0,
"message": {
"role": "assistant",
"content": "Quantum computing is..."
},
"finish_reason": "stop"
}],
"usage": {
"prompt_tokens": 25,
"completion_tokens": 150,
"total_tokens": 175
}
}
```
### Streaming
```bash
curl -X POST http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "meta-llama/Meta-Llama-3-8B",
"messages": [{"role": "user", "content": "Count to 10"}],
"stream": true
}'
```
**Response** (SSE stream):
```
data: {"choices":[{"delta":{"content":"1"}}]}
data: {"choices":[{"delta":{"content":", 2"}}]}
data: {"choices":[{"delta":{"content":", 3"}}]}
data: [DONE]
```
### Completions
```bash
curl -X POST http://localhost:8000/v1/completions \
-H "Content-Type: application/json" \
-d '{
"model": "meta-llama/Meta-Llama-3-8B",
"prompt": "The capital of France is",
"max_tokens": 10,
"temperature": 0.0
}'
```
## Monitoring
### Prometheus Metrics
**Enable metrics**:
```bash
trtllm-serve meta-llama/Meta-Llama-3-8B \
--enable_metrics \
--metrics_port 9090
```
**Key metrics**:
```bash
# Scrape metrics
curl http://localhost:9090/metrics
# Important metrics:
# - trtllm_request_success_total - Total successful requests
# - trtllm_request_latency_seconds - Request latency histogram
# - trtllm_tokens_generated_total - Total tokens generated
# - trtllm_active_requests - Current active requests
# - trtllm_queue_size - Requests waiting in queue
# - trtllm_gpu_memory_usage_bytes - GPU memory usage
# - trtllm_kv_cache_usage_ratio - KV cache utilization
```
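For automated checks you can poll the metrics endpoint directly. A minimal sketch using `requests`, assuming the metric names listed above and the default metrics port; the alert thresholds are illustrative.
```python
import requests

def read_metric(text: str, name: str) -> float:
    """Return the first sample value for a Prometheus metric, or 0.0 if absent."""
    for line in text.splitlines():
        if line.startswith(name):
            return float(line.rsplit(" ", 1)[-1])
    return 0.0

text = requests.get("http://localhost:9090/metrics", timeout=5).text
queued = read_metric(text, "trtllm_queue_size")
kv_usage = read_metric(text, "trtllm_kv_cache_usage_ratio")

if queued > 50:
    print(f"WARNING: {queued:.0f} requests queued - consider scaling out")
if kv_usage > 0.95:
    print(f"WARNING: KV cache {kv_usage:.0%} full - reduce max_batch_size or add capacity")
```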
### Health Checks
```bash
# Readiness probe
curl http://localhost:8000/health/ready
# Liveness probe
curl http://localhost:8000/health/live
# Model info
curl http://localhost:8000/v1/models
```
**Kubernetes probes**:
```yaml
livenessProbe:
httpGet:
path: /health/live
port: 8000
initialDelaySeconds: 60
periodSeconds: 10
readinessProbe:
httpGet:
path: /health/ready
port: 8000
initialDelaySeconds: 30
periodSeconds: 5
```
## Production Deployment
### Docker Deployment
**Dockerfile**:
```dockerfile
FROM nvidia/tensorrt_llm:latest
# Copy any custom configs
COPY config.yaml /app/config.yaml
# Expose ports
EXPOSE 8000 9090
# Start server
CMD ["trtllm-serve", "meta-llama/Meta-Llama-3-8B", \
"--tp_size", "4", \
"--dtype", "fp8", \
"--max_batch_size", "256", \
"--enable_metrics", \
"--metrics_port", "9090"]
```
**Run container**:
```bash
docker run --gpus all -p 8000:8000 -p 9090:9090 \
tensorrt-llm:latest
```
### Kubernetes Deployment
**Complete deployment**:
```yaml
apiVersion: apps/v1
kind: Deployment
metadata:
name: tensorrt-llm
spec:
replicas: 2 # Multiple replicas for HA
selector:
matchLabels:
app: tensorrt-llm
template:
metadata:
labels:
app: tensorrt-llm
spec:
containers:
- name: trtllm
image: nvidia/tensorrt_llm:latest
command:
- trtllm-serve
- meta-llama/Meta-Llama-3-70B
- --tp_size=4
- --dtype=fp8
- --max_batch_size=256
- --enable_metrics
ports:
- containerPort: 8000
name: http
- containerPort: 9090
name: metrics
resources:
limits:
nvidia.com/gpu: 4
livenessProbe:
httpGet:
path: /health/live
port: 8000
readinessProbe:
httpGet:
path: /health/ready
port: 8000
---
apiVersion: v1
kind: Service
metadata:
name: tensorrt-llm
spec:
selector:
app: tensorrt-llm
ports:
- name: http
port: 80
targetPort: 8000
- name: metrics
port: 9090
targetPort: 9090
type: LoadBalancer
```
### Load Balancing
**NGINX configuration**:
```nginx
upstream tensorrt_llm {
least_conn; # Route to least busy server
server trtllm-1:8000 max_fails=3 fail_timeout=30s;
server trtllm-2:8000 max_fails=3 fail_timeout=30s;
server trtllm-3:8000 max_fails=3 fail_timeout=30s;
}
server {
listen 80;
location / {
proxy_pass http://tensorrt_llm;
proxy_read_timeout 300s; # Long timeout for slow generations
proxy_connect_timeout 10s;
}
}
```
## Autoscaling
### Horizontal Pod Autoscaler (HPA)
```yaml
apiVersion: autoscaling/v2
kind: HorizontalPodAutoscaler
metadata:
name: tensorrt-llm-hpa
spec:
scaleTargetRef:
apiVersion: apps/v1
kind: Deployment
name: tensorrt-llm
minReplicas: 2
maxReplicas: 10
metrics:
- type: Pods
pods:
metric:
name: trtllm_active_requests
target:
type: AverageValue
averageValue: "50" # Scale when avg >50 active requests
```
### Custom Metrics
```yaml
# Scale based on queue size
- type: Pods
pods:
metric:
name: trtllm_queue_size
target:
type: AverageValue
averageValue: "10"
```
## Cost Optimization
### GPU Selection
**A100 80GB** ($3-4/hour):
- Use for: 70B models with FP8
- Throughput: 10,000-15,000 tok/s (TP=4)
- Cost per 1M tokens: $0.20-0.30
**H100 80GB** ($6-8/hour):
- Use for: 70B models with FP8, 405B models
- Throughput: 20,000-30,000 tok/s (TP=4)
- Cost per 1M tokens: $0.15-0.25 (2× faster = lower cost)
**L4** ($0.50-1/hour):
- Use for: 7-8B models
- Throughput: 1,000-2,000 tok/s
- Cost per 1M tokens: $0.25-0.50
### Batch Size Tuning
**Impact on cost**:
- Batch size 1: 1,000 tok/s ≈ 3.6M tokens/hour → ~$0.83 per 1M tokens at $3/hour
- Batch size 64: 5,000 tok/s ≈ 18M tokens/hour → ~$0.17 per 1M tokens
- **5× cost reduction** with batching
**Recommendation**: Target batch size 32-128 for cost efficiency.
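A quick way to sanity-check these numbers is to derive cost per million tokens from instance price and sustained throughput. A small illustrative helper using the rough figures quoted above:
```python
def cost_per_million_tokens(gpu_hourly_usd: float, tokens_per_sec: float) -> float:
    """USD per 1M generated tokens at a sustained throughput."""
    tokens_per_hour = tokens_per_sec * 3600
    return gpu_hourly_usd / (tokens_per_hour / 1_000_000)

# $3/hour instance at different batch sizes
print(f"batch=1:  ${cost_per_million_tokens(3.0, 1_000):.2f} per 1M tokens")  # ~$0.83
print(f"batch=64: ${cost_per_million_tokens(3.0, 5_000):.2f} per 1M tokens")  # ~$0.17
```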
## Security
### API Authentication
```bash
# Generate API key
export API_KEY=$(openssl rand -hex 32)
# Start server with authentication
trtllm-serve meta-llama/Meta-Llama-3-8B \
--api_key $API_KEY
# Client request
curl -X POST http://localhost:8000/v1/chat/completions \
-H "Authorization: Bearer $API_KEY" \
-H "Content-Type: application/json" \
-d '{"model": "...", "messages": [...]}'
```
### Network Policies
```yaml
apiVersion: networking.k8s.io/v1
kind: NetworkPolicy
metadata:
name: tensorrt-llm-policy
spec:
podSelector:
matchLabels:
app: tensorrt-llm
policyTypes:
- Ingress
ingress:
- from:
- podSelector:
matchLabels:
app: api-gateway # Only allow from gateway
ports:
- protocol: TCP
port: 8000
```
## Troubleshooting
### High latency
**Diagnosis**:
```bash
# Check queue size
curl http://localhost:9090/metrics | grep queue_size
# Check active requests
curl http://localhost:9090/metrics | grep active_requests
```
**Solutions**:
- Scale horizontally (more replicas)
- Increase batch size (if GPU underutilized)
- Enable chunked context (if long prompts)
- Use FP8 quantization
### OOM crashes
**Solutions**:
- Reduce `max_batch_size`
- Reduce `max_num_tokens`
- Enable FP8 or INT4 quantization
- Increase `tensor_parallel_size`
### Timeout errors
**NGINX config**:
```nginx
proxy_read_timeout 600s; # 10 minutes for very long generations
proxy_send_timeout 600s;
```
## Best Practices
1. **Use FP8 on H100** for 2× speedup and 50% cost reduction
2. **Monitor metrics** - Set up Prometheus + Grafana
3. **Set readiness probes** - Prevent routing to unhealthy pods
4. **Use load balancing** - Distribute load across replicas
5. **Tune batch size** - Balance latency and throughput
6. **Enable streaming** - Better UX for chat applications
7. **Set up autoscaling** - Handle traffic spikes
8. **Use persistent volumes** - Cache compiled models
9. **Implement retries** - Handle transient failures
10. **Monitor costs** - Track cost per token

View file

@ -0,0 +1,367 @@
---
name: serving-llms-vllm
description: Serves LLMs with high throughput using vLLM's PagedAttention and continuous batching. Use when deploying production LLM APIs, optimizing inference latency/throughput, or serving models with limited GPU memory. Supports OpenAI-compatible endpoints, quantization (GPTQ/AWQ/FP8), and tensor parallelism.
version: 1.0.0
author: Orchestra Research
license: MIT
dependencies: [vllm, torch, transformers]
metadata:
hermes:
tags: [vLLM, Inference Serving, PagedAttention, Continuous Batching, High Throughput, Production, OpenAI API, Quantization, Tensor Parallelism]
---
# vLLM - High-Performance LLM Serving
## Quick start
vLLM achieves 24x higher throughput than standard transformers through PagedAttention (block-based KV cache) and continuous batching (mixing prefill/decode requests).
**Installation**:
```bash
pip install vllm
```
**Basic offline inference**:
```python
from vllm import LLM, SamplingParams
llm = LLM(model="meta-llama/Llama-3-8B-Instruct")
sampling = SamplingParams(temperature=0.7, max_tokens=256)
outputs = llm.generate(["Explain quantum computing"], sampling)
print(outputs[0].outputs[0].text)
```
**OpenAI-compatible server**:
```bash
vllm serve meta-llama/Llama-3-8B-Instruct
# Query with OpenAI SDK
python -c "
from openai import OpenAI
client = OpenAI(base_url='http://localhost:8000/v1', api_key='EMPTY')
print(client.chat.completions.create(
model='meta-llama/Llama-3-8B-Instruct',
messages=[{'role': 'user', 'content': 'Hello!'}]
).choices[0].message.content)
"
```
## Common workflows
### Workflow 1: Production API deployment
Copy this checklist and track progress:
```
Deployment Progress:
- [ ] Step 1: Configure server settings
- [ ] Step 2: Test with limited traffic
- [ ] Step 3: Enable monitoring
- [ ] Step 4: Deploy to production
- [ ] Step 5: Verify performance metrics
```
**Step 1: Configure server settings**
Choose configuration based on your model size:
```bash
# For 7B-13B models on single GPU
vllm serve meta-llama/Llama-3-8B-Instruct \
--gpu-memory-utilization 0.9 \
--max-model-len 8192 \
--port 8000
# For 30B-70B models with tensor parallelism
vllm serve meta-llama/Llama-2-70b-hf \
--tensor-parallel-size 4 \
--gpu-memory-utilization 0.9 \
--quantization awq \
--port 8000
# For production with caching and metrics
vllm serve meta-llama/Llama-3-8B-Instruct \
--gpu-memory-utilization 0.9 \
--enable-prefix-caching \
--enable-metrics \
--metrics-port 9090 \
--port 8000 \
--host 0.0.0.0
```
**Step 2: Test with limited traffic**
Run load test before production:
```bash
# Install load testing tool
pip install locust
# Create test_load.py with sample requests
# Run: locust -f test_load.py --host http://localhost:8000
```
Verify TTFT (time to first token) < 500ms and throughput > 100 req/sec.
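A minimal `test_load.py` sketch for the Locust run above, hitting the OpenAI-compatible chat endpoint (the model name and request shape are assumptions; adjust to your deployment):
```python
# test_load.py - minimal Locust load test for an OpenAI-compatible vLLM server
from locust import HttpUser, task, between

class ChatUser(HttpUser):
    wait_time = between(0.5, 2.0)  # think time between requests

    @task
    def chat_completion(self):
        self.client.post(
            "/v1/chat/completions",
            json={
                "model": "meta-llama/Llama-3-8B-Instruct",
                "messages": [{"role": "user", "content": "Summarize PagedAttention in two sentences."}],
                "max_tokens": 128,
                "temperature": 0.7,
            },
            timeout=120,
        )
```
Watch p50/p95 latency and the failure rate in the Locust UI while ramping up users.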
**Step 3: Enable monitoring**
vLLM exposes Prometheus metrics on port 9090:
```bash
curl http://localhost:9090/metrics | grep vllm
```
Key metrics to monitor:
- `vllm:time_to_first_token_seconds` - Latency
- `vllm:num_requests_running` - Active requests
- `vllm:gpu_cache_usage_perc` - KV cache utilization
**Step 4: Deploy to production**
Use Docker for consistent deployment:
```bash
# Run vLLM in Docker
docker run --gpus all -p 8000:8000 \
vllm/vllm-openai:latest \
--model meta-llama/Llama-3-8B-Instruct \
--gpu-memory-utilization 0.9 \
--enable-prefix-caching
```
**Step 5: Verify performance metrics**
Check that deployment meets targets:
- TTFT < 500ms (for short prompts)
- Throughput > target req/sec
- GPU utilization > 80%
- No OOM errors in logs
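As a quick spot-check of the TTFT target, you can time the first streamed chunk directly. An illustrative sketch using the OpenAI SDK (model name assumed):
```python
import time
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

start = time.perf_counter()
stream = client.chat.completions.create(
    model="meta-llama/Llama-3-8B-Instruct",
    messages=[{"role": "user", "content": "Hello!"}],
    max_tokens=64,
    stream=True,
)
ttft = None
for chunk in stream:
    if ttft is None:
        ttft = time.perf_counter() - start  # time to first streamed chunk
print(f"TTFT: {ttft * 1000:.0f} ms (target: < 500 ms)" if ttft else "No tokens received")
```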
### Workflow 2: Offline batch inference
For processing large datasets without server overhead.
Copy this checklist:
```
Batch Processing:
- [ ] Step 1: Prepare input data
- [ ] Step 2: Configure LLM engine
- [ ] Step 3: Run batch inference
- [ ] Step 4: Process results
```
**Step 1: Prepare input data**
```python
# Load prompts from file
prompts = []
with open("prompts.txt") as f:
prompts = [line.strip() for line in f]
print(f"Loaded {len(prompts)} prompts")
```
**Step 2: Configure LLM engine**
```python
from vllm import LLM, SamplingParams
llm = LLM(
model="meta-llama/Llama-3-8B-Instruct",
tensor_parallel_size=2, # Use 2 GPUs
gpu_memory_utilization=0.9,
max_model_len=4096
)
sampling = SamplingParams(
temperature=0.7,
top_p=0.95,
max_tokens=512,
stop=["</s>", "\n\n"]
)
```
**Step 3: Run batch inference**
vLLM automatically batches requests for efficiency:
```python
# Process all prompts in one call
outputs = llm.generate(prompts, sampling)
# vLLM handles batching internally
# No need to manually chunk prompts
```
**Step 4: Process results**
```python
# Extract generated text
results = []
for output in outputs:
prompt = output.prompt
generated = output.outputs[0].text
results.append({
"prompt": prompt,
"generated": generated,
"tokens": len(output.outputs[0].token_ids)
})
# Save to file
import json
with open("results.jsonl", "w") as f:
for result in results:
f.write(json.dumps(result) + "\n")
print(f"Processed {len(results)} prompts")
```
### Workflow 3: Quantized model serving
Fit large models in limited GPU memory.
```
Quantization Setup:
- [ ] Step 1: Choose quantization method
- [ ] Step 2: Find or create quantized model
- [ ] Step 3: Launch with quantization flag
- [ ] Step 4: Verify accuracy
```
**Step 1: Choose quantization method**
- **AWQ**: Best for 70B models, minimal accuracy loss
- **GPTQ**: Wide model support, good compression
- **FP8**: Fastest on H100 GPUs
**Step 2: Find or create quantized model**
Use pre-quantized models from HuggingFace:
```bash
# Search for AWQ models
# Example: TheBloke/Llama-2-70B-AWQ
```
**Step 3: Launch with quantization flag**
```bash
# Using pre-quantized model
vllm serve TheBloke/Llama-2-70B-AWQ \
--quantization awq \
--tensor-parallel-size 1 \
--gpu-memory-utilization 0.95
# Results: 70B model in ~40GB VRAM
```
**Step 4: Verify accuracy**
Test outputs match expected quality:
```python
# Compare quantized vs non-quantized responses
# Verify task-specific performance unchanged
```
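A minimal sketch of such a check: send the same prompts, at temperature 0, to a quantized and an unquantized deployment and compare the answers side by side (the endpoints and prompts below are assumptions):
```python
from openai import OpenAI

def sample(base_url: str, model: str, prompts: list[str]) -> list[str]:
    client = OpenAI(base_url=base_url, api_key="EMPTY")
    return [
        client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": p}],
            temperature=0.0,  # deterministic for easier comparison
            max_tokens=256,
        ).choices[0].message.content
        for p in prompts
    ]

prompts = ["Explain quantum entanglement", "Write a SQL query joining two tables"]
baseline = sample("http://fp16-host:8000/v1", "meta-llama/Llama-2-70b-hf", prompts)
quantized = sample("http://awq-host:8000/v1", "TheBloke/Llama-2-70B-AWQ", prompts)

for p, b, q in zip(prompts, baseline, quantized):
    print(f"PROMPT: {p}\nFP16: {b[:200]}\nAWQ:  {q[:200]}\n")
```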
## When to use vs alternatives
**Use vLLM when:**
- Deploying production LLM APIs (100+ req/sec)
- Serving OpenAI-compatible endpoints
- Limited GPU memory but need large models
- Multi-user applications (chatbots, assistants)
- Need low latency with high throughput
**Use alternatives instead:**
- **llama.cpp**: CPU/edge inference, single-user
- **HuggingFace transformers**: Research, prototyping, one-off generation
- **TensorRT-LLM**: NVIDIA-only, need absolute maximum performance
- **Text-Generation-Inference**: Already in HuggingFace ecosystem
## Common issues
**Issue: Out of memory during model loading**
Reduce memory usage:
```bash
vllm serve MODEL \
--gpu-memory-utilization 0.7 \
--max-model-len 4096
```
Or use quantization:
```bash
vllm serve MODEL --quantization awq
```
**Issue: Slow first token (TTFT > 1 second)**
Enable prefix caching for repeated prompts:
```bash
vllm serve MODEL --enable-prefix-caching
```
For long prompts, enable chunked prefill:
```bash
vllm serve MODEL --enable-chunked-prefill
```
**Issue: Model not found error**
Use `--trust-remote-code` for custom models:
```bash
vllm serve MODEL --trust-remote-code
```
**Issue: Low throughput (<50 req/sec)**
Increase concurrent sequences:
```bash
vllm serve MODEL --max-num-seqs 512
```
Check GPU utilization with `nvidia-smi` - should be >80%.
**Issue: Inference slower than expected**
Verify tensor parallelism uses power of 2 GPUs:
```bash
vllm serve MODEL --tensor-parallel-size 4 # Not 3
```
Enable speculative decoding for faster generation:
```bash
vllm serve MODEL --speculative-model DRAFT_MODEL
```
## Advanced topics
**Server deployment patterns**: See [references/server-deployment.md](references/server-deployment.md) for Docker, Kubernetes, and load balancing configurations.
**Performance optimization**: See [references/optimization.md](references/optimization.md) for PagedAttention tuning, continuous batching details, and benchmark results.
**Quantization guide**: See [references/quantization.md](references/quantization.md) for AWQ/GPTQ/FP8 setup, model preparation, and accuracy comparisons.
**Troubleshooting**: See [references/troubleshooting.md](references/troubleshooting.md) for detailed error messages, debugging steps, and performance diagnostics.
## Hardware requirements
- **Small models (7B-13B)**: 1x A10 (24GB) or A100 (40GB)
- **Medium models (30B-40B)**: 2x A100 (40GB) with tensor parallelism
- **Large models (70B+)**: 4x A100 (40GB) or 2x A100 (80GB), use AWQ/GPTQ
Supported platforms: NVIDIA (primary), AMD ROCm, Intel GPUs, TPUs
## Resources
- Official docs: https://docs.vllm.ai
- GitHub: https://github.com/vllm-project/vllm
- Paper: "Efficient Memory Management for Large Language Model Serving with PagedAttention" (SOSP 2023)
- Community: https://discuss.vllm.ai

View file

@ -0,0 +1,226 @@
# Performance Optimization
## Contents
- PagedAttention explained
- Continuous batching mechanics
- Prefix caching strategies
- Speculative decoding setup
- Benchmark results and comparisons
- Performance tuning guide
## PagedAttention explained
**Traditional attention problem**:
- KV cache stored in contiguous memory
- Wastes ~50% GPU memory due to fragmentation
- Cannot dynamically reallocate for varying sequence lengths
**PagedAttention solution**:
- Divides KV cache into fixed-size blocks (like OS virtual memory)
- Dynamic allocation from free block queue
- Shares blocks across sequences (for prefix caching)
**Memory savings example**:
```
Traditional: 70B model needs 160GB KV cache → OOM on 8x A100
PagedAttention: 70B model needs 80GB KV cache → Fits on 4x A100
```
**Configuration**:
```bash
# Block size (default: 16 tokens)
vllm serve MODEL --block-size 16
# Number of GPU blocks (auto-calculated)
# Controlled by --gpu-memory-utilization
vllm serve MODEL --gpu-memory-utilization 0.9
```
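To see why block-based allocation matters, it helps to estimate how large the KV cache actually gets. A back-of-the-envelope sketch; the Llama 3-8B-style dimensions below (32 layers, 8 KV heads with GQA, head dim 128) are assumptions for illustration:
```python
def kv_cache_bytes(num_layers: int, num_kv_heads: int, head_dim: int,
                   seq_len: int, batch_size: int, bytes_per_elem: int = 2) -> int:
    """KV cache size: 2 (K and V) x layers x kv_heads x head_dim x tokens x dtype bytes."""
    per_token = 2 * num_layers * num_kv_heads * head_dim * bytes_per_elem
    return per_token * seq_len * batch_size

per_seq = kv_cache_bytes(num_layers=32, num_kv_heads=8, head_dim=128,
                         seq_len=8192, batch_size=1)
print(f"{per_seq / 1e9:.2f} GB per 8K-token sequence")   # ~1.07 GB
print(f"{64 * per_seq / 1e9:.0f} GB for a batch of 64")  # ~69 GB - why paging matters
```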
## Continuous batching mechanics
**Traditional batching**:
- Wait for all sequences in batch to finish
- GPU idle while waiting for longest sequence
- Low GPU utilization (~40-60%)
**Continuous batching**:
- Add new requests as slots become available
- Mix prefill (new requests) and decode (ongoing) in same batch
- High GPU utilization (>90%)
**Throughput improvement**:
```
Traditional batching: 50 req/sec @ 50% GPU util
Continuous batching: 200 req/sec @ 90% GPU util
= 4x throughput improvement
```
**Tuning parameters**:
```bash
# Max concurrent sequences (higher = more batching)
vllm serve MODEL --max-num-seqs 256
# Prefill/decode schedule (auto-balanced by default)
# No manual tuning needed
```
## Prefix caching strategies
Reuse computed KV cache for common prompt prefixes.
**Use cases**:
- System prompts repeated across requests
- Few-shot examples in every prompt
- RAG contexts with overlapping chunks
**Example savings**:
```
Prompt: [System: 500 tokens] + [User: 100 tokens]
Without caching: Compute 600 tokens every request
With caching: Compute 500 tokens once, then 100 tokens/request
= 83% faster TTFT
```
**Enable prefix caching**:
```bash
vllm serve MODEL --enable-prefix-caching
```
**Automatic prefix detection**:
- vLLM detects common prefixes automatically
- No code changes required
- Works with OpenAI-compatible API
**Cache hit rate monitoring**:
```bash
curl http://localhost:9090/metrics | grep cache_hit
# vllm_cache_hit_rate: 0.75 (75% hit rate)
```
## Speculative decoding setup
Use smaller "draft" model to propose tokens, larger model to verify.
**Speed improvement**:
```
Standard: Generate 1 token per forward pass
Speculative: Generate 3-5 tokens per forward pass
= 2-3x faster generation
```
**How it works**:
1. Draft model proposes K tokens (fast)
2. Target model verifies all K tokens in parallel (one pass)
3. Accept verified tokens, restart from first rejection
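The loop can be sketched in a few lines in its simplest (greedy) form. `draft_next_token` and `target_greedy_tokens` are placeholder callables standing in for the two models, not vLLM APIs:
```python
def speculative_step(tokens: list[int], draft_next_token, target_greedy_tokens, k: int = 5) -> list[int]:
    """One round of greedy speculative decoding (toy illustration)."""
    # 1. Draft model proposes k tokens autoregressively (cheap, sequential)
    proposed, ctx = [], list(tokens)
    for _ in range(k):
        t = draft_next_token(ctx)
        proposed.append(t)
        ctx.append(t)

    # 2. Target model scores prompt + proposal in ONE forward pass,
    #    returning its own greedy choice at each proposed position
    target_choices = target_greedy_tokens(tokens, proposed)

    # 3. Accept the longest matching prefix; at the first mismatch keep
    #    the target's token and stop (output quality matches the target model)
    accepted = []
    for drafted, target in zip(proposed, target_choices):
        accepted.append(drafted if drafted == target else target)
        if drafted != target:
            break
    return tokens + accepted
```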
**Setup with separate draft model**:
```bash
vllm serve meta-llama/Llama-3-70B-Instruct \
--speculative-model TinyLlama/TinyLlama-1.1B-Chat-v1.0 \
--num-speculative-tokens 5
```
**Setup with n-gram draft** (no separate model):
```bash
vllm serve MODEL \
--speculative-method ngram \
--num-speculative-tokens 3
```
**When to use**:
- Output length > 100 tokens
- Draft model 5-10x smaller than target
- Acceptable 2-3% accuracy trade-off
## Benchmark results
**vLLM vs HuggingFace Transformers** (Llama 3 8B, A100):
```
Metric | HF Transformers | vLLM | Improvement
------------------------|-----------------|--------|------------
Throughput (req/sec) | 12 | 280 | 23x
TTFT (ms) | 850 | 120 | 7x
Tokens/sec | 45 | 2,100 | 47x
GPU Memory (GB) | 28 | 16 | 1.75x less
```
**vLLM vs TensorRT-LLM** (Llama 2 70B, 4x A100):
```
Metric | TensorRT-LLM | vLLM | Notes
------------------------|--------------|--------|------------------
Throughput (req/sec) | 320 | 285 | TRT 12% faster
Setup complexity | High | Low | vLLM much easier
NVIDIA-only | Yes | No | vLLM multi-platform
Quantization support | FP8, INT8 | AWQ/GPTQ/FP8 | vLLM more options
```
## Performance tuning guide
**Step 1: Measure baseline**
```bash
# Install benchmarking tool
pip install locust
# Run baseline benchmark
vllm bench throughput \
--model MODEL \
--input-tokens 128 \
--output-tokens 256 \
--num-prompts 1000
# Record: throughput, TTFT, tokens/sec
```
**Step 2: Tune memory utilization**
```bash
# Try different values: 0.7, 0.85, 0.9, 0.95
vllm serve MODEL --gpu-memory-utilization 0.9
```
Higher = more batch capacity = higher throughput, but risk OOM.
**Step 3: Tune concurrency**
```bash
# Try values: 128, 256, 512, 1024
vllm serve MODEL --max-num-seqs 256
```
Higher = more batching opportunity, but may increase latency.
**Step 4: Enable optimizations**
```bash
# prefix caching helps repeated prompts; chunked prefill helps long prompts
vllm serve MODEL \
  --enable-prefix-caching \
  --enable-chunked-prefill \
  --gpu-memory-utilization 0.9 \
  --max-num-seqs 512
```
**Step 5: Re-benchmark and compare**
Target improvements:
- Throughput: +30-100%
- TTFT: -20-50%
- GPU utilization: >85%
**Common performance issues**:
**Low throughput (<50 req/sec)**:
- Increase `--max-num-seqs`
- Enable `--enable-prefix-caching`
- Check GPU utilization (should be >80%)
**High TTFT (>1 second)**:
- Enable `--enable-chunked-prefill`
- Reduce `--max-model-len` if possible
- Check if model is too large for GPU
**OOM errors**:
- Reduce `--gpu-memory-utilization` to 0.7
- Reduce `--max-model-len`
- Use quantization (`--quantization awq`)

View file

@ -0,0 +1,284 @@
# Quantization Guide
## Contents
- Quantization methods comparison
- AWQ setup and usage
- GPTQ setup and usage
- FP8 quantization (H100)
- Model preparation
- Accuracy vs compression trade-offs
## Quantization methods comparison
| Method | Compression | Accuracy Loss | Speed | Best For |
|--------|-------------|---------------|-------|----------|
| **AWQ** | 4-bit (75%) | <1% | Fast | 70B models, production |
| **GPTQ** | 4-bit (75%) | 1-2% | Fast | Wide model support |
| **FP8** | 8-bit (50%) | <0.5% | Fastest | H100 GPUs only |
| **SqueezeLLM** | 3-4 bit (75-80%) | 2-3% | Medium | Extreme compression |
**Recommendation**:
- **Production**: Use AWQ for 70B models
- **H100 GPUs**: Use FP8 for best speed
- **Maximum compatibility**: Use GPTQ
- **Extreme compression**: Use SqueezeLLM
## AWQ setup and usage
**AWQ** (Activation-aware Weight Quantization) achieves best accuracy at 4-bit.
**Step 1: Find pre-quantized model**
Search HuggingFace for AWQ models:
```bash
# Example: TheBloke/Llama-2-70B-AWQ
# Example: TheBloke/Mixtral-8x7B-Instruct-v0.1-AWQ
```
**Step 2: Launch with AWQ**
```bash
vllm serve TheBloke/Llama-2-70B-AWQ \
--quantization awq \
--tensor-parallel-size 1 \
--gpu-memory-utilization 0.95
```
**Memory savings**:
```
Llama 2 70B fp16: 140GB VRAM (4x A100 needed)
Llama 2 70B AWQ: 35GB VRAM (1x A100 40GB)
= 4x memory reduction
```
**Step 3: Verify performance**
Test that outputs are acceptable:
```python
from openai import OpenAI
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
# Test complex reasoning
response = client.chat.completions.create(
model="TheBloke/Llama-2-70B-AWQ",
messages=[{"role": "user", "content": "Explain quantum entanglement"}]
)
print(response.choices[0].message.content)
# Verify quality matches your requirements
```
**Quantize your own model** (requires GPU with 80GB+ VRAM):
```python
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer
model_path = "meta-llama/Llama-2-70b-hf"
quant_path = "llama-2-70b-awq"
# Load model
model = AutoAWQForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)
# Quantize
quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4}
model.quantize(tokenizer, quant_config=quant_config)
# Save
model.save_quantized(quant_path)
tokenizer.save_pretrained(quant_path)
```
## GPTQ setup and usage
**GPTQ** has widest model support and good compression.
**Step 1: Find GPTQ model**
```bash
# Example: TheBloke/Llama-2-13B-GPTQ
# Example: TheBloke/CodeLlama-34B-GPTQ
```
**Step 2: Launch with GPTQ**
```bash
vllm serve TheBloke/Llama-2-13B-GPTQ \
--quantization gptq \
--dtype float16
```
**GPTQ configuration options**:
```bash
# Specify GPTQ parameters if needed
vllm serve MODEL \
--quantization gptq \
  --gptq-act-order \
  --dtype float16   # --gptq-act-order enables activation ordering
```
**Quantize your own model**:
```python
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
from transformers import AutoTokenizer
model_name = "meta-llama/Llama-2-13b-hf"
quantized_name = "llama-2-13b-gptq"
# Quantization config (must be defined before loading the model)
quantize_config = BaseQuantizeConfig(
    bits=4,
    group_size=128,
    desc_act=True
)
# Load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoGPTQForCausalLM.from_pretrained(model_name, quantize_config)
# Prepare calibration data (tokenized examples from your target domain)
calib_data = [...]  # List of tokenized samples
# Quantize and save
model.quantize(calib_data)
model.save_quantized(quantized_name)
tokenizer.save_pretrained(quantized_name)
```
## FP8 quantization (H100)
**FP8** (8-bit floating point) offers best speed on H100 GPUs with minimal accuracy loss.
**Requirements**:
- H100 or H800 GPU
- CUDA 12.3+ (12.8 recommended)
- Hopper architecture support
**Step 1: Enable FP8**
```bash
vllm serve meta-llama/Llama-3-70B-Instruct \
--quantization fp8 \
--tensor-parallel-size 2
```
**Performance gains on H100**:
```
fp16: 180 tokens/sec
FP8: 320 tokens/sec
= 1.8x speedup
```
**Step 2: Verify accuracy**
FP8 typically has <0.5% accuracy degradation:
```python
# Run evaluation suite
# Compare FP8 vs FP16 on your tasks
# Verify acceptable accuracy
```
**Dynamic FP8 quantization** (no pre-quantized model needed):
```bash
# vLLM automatically quantizes at runtime
vllm serve MODEL --quantization fp8
# No model preparation required
```
## Model preparation
**Pre-quantized models (easiest)**:
1. Search HuggingFace: `[model name] AWQ` or `[model name] GPTQ`
2. Download or use directly: `TheBloke/[Model]-AWQ`
3. Launch with appropriate `--quantization` flag
**Quantize your own model**:
**AWQ**:
```bash
# Install AutoAWQ
pip install autoawq
# Run quantization script
python quantize_awq.py --model MODEL --output OUTPUT
```
**GPTQ**:
```bash
# Install AutoGPTQ
pip install auto-gptq
# Run quantization script
python quantize_gptq.py --model MODEL --output OUTPUT
```
**Calibration data**:
- Use 128-512 diverse examples from target domain
- Representative of production inputs
- Higher quality calibration = better accuracy
## Accuracy vs compression trade-offs
**Empirical results** (Llama 2 70B on MMLU benchmark):
| Quantization | Accuracy | Memory | Speed | Production-Ready |
|--------------|----------|--------|-------|------------------|
| FP16 (baseline) | 100% | 140GB | 1.0x | ✅ (if memory available) |
| FP8 | 99.5% | 70GB | 1.8x | ✅ (H100 only) |
| AWQ 4-bit | 99.0% | 35GB | 1.5x | ✅ (best for 70B) |
| GPTQ 4-bit | 98.5% | 35GB | 1.5x | ✅ (good compatibility) |
| SqueezeLLM 3-bit | 96.0% | 26GB | 1.3x | ⚠️ (check accuracy) |
**When to use each**:
**No quantization (FP16)**:
- Have sufficient GPU memory
- Need absolute best accuracy
- Model <13B parameters
**FP8**:
- Using H100/H800 GPUs
- Need best speed with minimal accuracy loss
- Production deployment
**AWQ 4-bit**:
- Need to fit 70B model in 40GB GPU
- Production deployment
- <1% accuracy loss acceptable
**GPTQ 4-bit**:
- Wide model support needed
- Not on H100 (use FP8 instead)
- 1-2% accuracy loss acceptable
**Testing strategy**:
1. **Baseline**: Measure FP16 accuracy on your evaluation set
2. **Quantize**: Create quantized version
3. **Evaluate**: Compare quantized vs baseline on same tasks
4. **Decide**: Accept if degradation < threshold (typically 1-2%)
**Example evaluation**:
```python
# Illustrative sketch: `evaluate` and `eval_suite` stand in for your own evaluation harness
# Run on FP16 baseline
baseline_score = evaluate(model_fp16, eval_suite)
# Run on quantized
quant_score = evaluate(model_awq, eval_suite)
# Compare
degradation = (baseline_score - quant_score) / baseline_score * 100
print(f"Accuracy degradation: {degradation:.2f}%")
# Decision
if degradation < 1.0:
print("✅ Quantization acceptable for production")
else:
print("⚠️ Review accuracy loss")
```

View file

@ -0,0 +1,255 @@
# Server Deployment Patterns
## Contents
- Docker deployment
- Kubernetes deployment
- Load balancing with Nginx
- Multi-node distributed serving
- Production configuration examples
- Health checks and monitoring
## Docker deployment
**Basic Dockerfile**:
```dockerfile
FROM nvidia/cuda:12.1.0-devel-ubuntu22.04
RUN apt-get update && apt-get install -y python3-pip
RUN pip install vllm
EXPOSE 8000
CMD ["vllm", "serve", "meta-llama/Llama-3-8B-Instruct", \
"--host", "0.0.0.0", "--port", "8000", \
"--gpu-memory-utilization", "0.9"]
```
**Build and run**:
```bash
docker build -t vllm-server .
docker run --gpus all -p 8000:8000 vllm-server
```
**Docker Compose** (with metrics):
```yaml
version: '3.8'
services:
vllm:
image: vllm/vllm-openai:latest
command: >
--model meta-llama/Llama-3-8B-Instruct
--gpu-memory-utilization 0.9
--enable-metrics
--metrics-port 9090
ports:
- "8000:8000"
- "9090:9090"
deploy:
resources:
reservations:
devices:
- driver: nvidia
count: all
capabilities: [gpu]
```
## Kubernetes deployment
**Deployment manifest**:
```yaml
apiVersion: apps/v1
kind: Deployment
metadata:
name: vllm-server
spec:
replicas: 2
selector:
matchLabels:
app: vllm
template:
metadata:
labels:
app: vllm
spec:
containers:
- name: vllm
image: vllm/vllm-openai:latest
args:
- "--model=meta-llama/Llama-3-8B-Instruct"
- "--gpu-memory-utilization=0.9"
- "--enable-prefix-caching"
resources:
limits:
nvidia.com/gpu: 1
ports:
- containerPort: 8000
name: http
- containerPort: 9090
name: metrics
readinessProbe:
httpGet:
path: /health
port: 8000
initialDelaySeconds: 30
periodSeconds: 10
livenessProbe:
httpGet:
path: /health
port: 8000
initialDelaySeconds: 60
periodSeconds: 30
---
apiVersion: v1
kind: Service
metadata:
name: vllm-service
spec:
selector:
app: vllm
ports:
- port: 8000
targetPort: 8000
name: http
- port: 9090
targetPort: 9090
name: metrics
type: LoadBalancer
```
## Load balancing with Nginx
**Nginx configuration**:
```nginx
upstream vllm_backend {
least_conn; # Route to least-loaded server
server localhost:8001;
server localhost:8002;
server localhost:8003;
}
server {
listen 80;
location / {
proxy_pass http://vllm_backend;
proxy_set_header Host $host;
proxy_set_header X-Real-IP $remote_addr;
# Timeouts for long-running inference
proxy_read_timeout 300s;
proxy_connect_timeout 75s;
}
# Metrics endpoint
location /metrics {
proxy_pass http://localhost:9090/metrics;
}
}
```
**Start multiple vLLM instances**:
```bash
# Terminal 1
vllm serve MODEL --port 8001 --tensor-parallel-size 1
# Terminal 2
vllm serve MODEL --port 8002 --tensor-parallel-size 1
# Terminal 3
vllm serve MODEL --port 8003 --tensor-parallel-size 1
# Start Nginx
nginx -c /path/to/nginx.conf
```
## Multi-node distributed serving
For models too large for single node:
**Node 1** (master):
```bash
export MASTER_ADDR=192.168.1.10
export MASTER_PORT=29500
export RANK=0
export WORLD_SIZE=2
vllm serve meta-llama/Llama-2-70b-hf \
--tensor-parallel-size 8 \
--pipeline-parallel-size 2
```
**Node 2** (worker):
```bash
export MASTER_ADDR=192.168.1.10
export MASTER_PORT=29500
export RANK=1
export WORLD_SIZE=2
vllm serve meta-llama/Llama-2-70b-hf \
--tensor-parallel-size 8 \
--pipeline-parallel-size 2
```
## Production configuration examples
**High throughput** (batch-heavy workload):
```bash
vllm serve MODEL \
--max-num-seqs 512 \
--gpu-memory-utilization 0.95 \
--enable-prefix-caching \
--trust-remote-code
```
**Low latency** (interactive workload):
```bash
vllm serve MODEL \
--max-num-seqs 64 \
--gpu-memory-utilization 0.85 \
--enable-chunked-prefill
```
**Memory-constrained** (40GB GPU for 70B model):
```bash
vllm serve TheBloke/Llama-2-70B-AWQ \
--quantization awq \
--tensor-parallel-size 1 \
--gpu-memory-utilization 0.95 \
--max-model-len 4096
```
## Health checks and monitoring
**Health check endpoint**:
```bash
curl http://localhost:8000/health
# Returns: {"status": "ok"}
```
**Readiness check** (wait for model loaded):
```bash
#!/bin/bash
until curl -f http://localhost:8000/health; do
echo "Waiting for vLLM to be ready..."
sleep 5
done
echo "vLLM is ready!"
```
**Prometheus scraping**:
```yaml
# prometheus.yml
scrape_configs:
- job_name: 'vllm'
static_configs:
- targets: ['localhost:9090']
metrics_path: '/metrics'
scrape_interval: 15s
```
**Grafana dashboard** (key metrics):
- Requests per second: `rate(vllm_request_success_total[5m])`
- TTFT p50: `histogram_quantile(0.5, vllm_time_to_first_token_seconds_bucket)`
- TTFT p99: `histogram_quantile(0.99, vllm_time_to_first_token_seconds_bucket)`
- GPU cache usage: `vllm_gpu_cache_usage_perc`
- Active requests: `vllm_num_requests_running`

View file

@ -0,0 +1,447 @@
# Troubleshooting Guide
## Contents
- Out of memory (OOM) errors
- Performance issues
- Model loading errors
- Network and connection issues
- Quantization problems
- Distributed serving issues
- Debugging tools and commands
## Out of memory (OOM) errors
### Symptom: `torch.cuda.OutOfMemoryError` during model loading
**Cause**: Model + KV cache exceeds available VRAM
**Solutions (try in order)**:
1. **Reduce GPU memory utilization**:
```bash
vllm serve MODEL --gpu-memory-utilization 0.7 # Try 0.7, 0.75, 0.8
```
2. **Reduce max sequence length**:
```bash
vllm serve MODEL --max-model-len 4096 # Instead of 8192
```
3. **Enable quantization**:
```bash
vllm serve MODEL --quantization awq # 4x memory reduction
```
4. **Use tensor parallelism** (multiple GPUs):
```bash
vllm serve MODEL --tensor-parallel-size 2 # Split across 2 GPUs
```
5. **Reduce max concurrent sequences**:
```bash
vllm serve MODEL --max-num-seqs 128 # Default is 256
```
### Symptom: OOM during inference (not model loading)
**Cause**: KV cache fills up during generation
**Solutions**:
```bash
# Reduce KV cache allocation
vllm serve MODEL --gpu-memory-utilization 0.85
# Reduce batch size
vllm serve MODEL --max-num-seqs 64
# Reduce max tokens per request
# Set in client request: max_tokens=512
```
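On the client side, the per-request cap can be set like this (a sketch using the OpenAI-compatible API shown elsewhere in this guide; `MODEL` is a placeholder for whatever the server was launched with):
```python
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

# Capping max_tokens bounds how much KV cache a single request can consume.
response = client.completions.create(
    model="MODEL",  # placeholder: use the model name the server is serving
    prompt="Summarize PagedAttention in one sentence.",
    max_tokens=512,
)
print(response.choices[0].text)
```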
### Symptom: OOM with quantized model
**Cause**: Quantization overhead or incorrect configuration
**Solution**:
```bash
# Ensure quantization flag matches model
vllm serve TheBloke/Llama-2-70B-AWQ --quantization awq # Must specify
# Try different dtype
vllm serve MODEL --quantization awq --dtype float16
```
## Performance issues
### Symptom: Low throughput (<50 req/s when >100 is expected)
**Diagnostic steps**:
1. **Check GPU utilization**:
```bash
watch -n 1 nvidia-smi
# GPU utilization should be >80%
```
If <80%, increase concurrent requests:
```bash
vllm serve MODEL --max-num-seqs 512 # Increase from 256
```
2. **Check if memory-bound**:
```bash
# If memory at 100% but GPU <80%, reduce sequence length
vllm serve MODEL --max-model-len 4096
```
3. **Enable optimizations**:
```bash
vllm serve MODEL \
--enable-prefix-caching \
--enable-chunked-prefill \
--max-num-seqs 512
```
4. **Check tensor parallelism settings**:
```bash
# Must use power-of-2 GPUs
vllm serve MODEL --tensor-parallel-size 4 # Not 3 or 5
```
### Symptom: High TTFT (time to first token >1 second)
**Causes and solutions**:
**Long prompts**:
```bash
vllm serve MODEL --enable-chunked-prefill
```
**No prefix caching**:
```bash
vllm serve MODEL --enable-prefix-caching # For repeated prompts
```
**Too many concurrent requests**:
```bash
vllm serve MODEL --max-num-seqs 64 # Reduce to prioritize latency
```
**Model too large for single GPU**:
```bash
vllm serve MODEL --tensor-parallel-size 2 # Parallelize prefill
```
### Symptom: Slow token generation (low tokens/sec)
**Diagnostic**:
```bash
# Check if model is correct size
vllm serve MODEL # Should see model size in logs
# Check speculative decoding
vllm serve MODEL --speculative-model DRAFT_MODEL
```
**For H100 GPUs**, enable FP8:
```bash
vllm serve MODEL --quantization fp8
```
## Model loading errors
### Symptom: `OSError: MODEL not found`
**Causes**:
1. **Model name typo**:
```bash
# Check exact model name on HuggingFace
vllm serve meta-llama/Llama-3-8B-Instruct # Correct capitalization
```
2. **Private/gated model**:
```bash
# Login to HuggingFace first
huggingface-cli login
# Then run vLLM
vllm serve meta-llama/Llama-3-70B-Instruct
```
3. **Custom model needs trust flag**:
```bash
vllm serve MODEL --trust-remote-code
```
### Symptom: `ValueError: Tokenizer not found`
**Solution**:
```bash
# Download model manually first
python -c "from transformers import AutoTokenizer; AutoTokenizer.from_pretrained('MODEL')"
# Then launch vLLM
vllm serve MODEL
```
### Symptom: `ImportError: No module named 'flash_attn'`
**Solution**:
```bash
# Install flash attention
pip install flash-attn --no-build-isolation
# Or fall back to a different attention backend
export VLLM_ATTENTION_BACKEND=XFORMERS
vllm serve MODEL
```
## Network and connection issues
### Symptom: `Connection refused` when querying server
**Diagnostic**:
1. **Check server is running**:
```bash
curl http://localhost:8000/health
```
2. **Check port binding**:
```bash
# Bind to all interfaces for remote access
vllm serve MODEL --host 0.0.0.0 --port 8000
# Check if port is in use
lsof -i :8000
```
3. **Check firewall**:
```bash
# Allow port through firewall
sudo ufw allow 8000
```
### Symptom: Slow response times over network
**Solutions**:
1. **Increase timeout**:
```python
from openai import OpenAI
client = OpenAI(
base_url="http://localhost:8000/v1",
api_key="EMPTY",
timeout=300.0 # 5 minute timeout
)
```
2. **Check network latency**:
```bash
ping SERVER_IP # Should be <10ms for local network
```
3. **Use connection pooling**:
```python
import requests
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry
session = requests.Session()
retries = Retry(total=3, backoff_factor=1)
session.mount('http://', HTTPAdapter(max_retries=retries))
# Reuse the pooled session for every call to the vLLM server
response = session.get("http://localhost:8000/health", timeout=5)
```
## Quantization problems
### Symptom: `RuntimeError: Quantization format not supported`
**Solution**:
```bash
# Ensure correct quantization method
vllm serve MODEL --quantization awq # For AWQ models
vllm serve MODEL --quantization gptq # For GPTQ models
# Check model card for quantization type
```
### Symptom: Poor quality outputs after quantization
**Diagnostic**:
1. **Verify model is correctly quantized**:
```bash
# Check model config.json for quantization_config
cat ~/.cache/huggingface/hub/models--MODEL/config.json
```
2. **Try different quantization method**:
```bash
# If AWQ quality issues, try FP8 (H100 only)
vllm serve MODEL --quantization fp8
# Or use less aggressive quantization
vllm serve MODEL # No quantization
```
3. **Increase temperature for better diversity**:
```python
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
```
## Distributed serving issues
### Symptom: `RuntimeError: Distributed init failed`
**Diagnostic**:
1. **Check environment variables**:
```bash
# On all nodes
echo $MASTER_ADDR # Should be same
echo $MASTER_PORT # Should be same
echo $RANK # Should be unique per node (0, 1, 2, ...)
echo $WORLD_SIZE # Should be same (total nodes)
```
2. **Check network connectivity**:
```bash
# From node 1 to node 2
ping NODE2_IP
nc -zv NODE2_IP 29500 # Check port accessibility
```
3. **Check NCCL settings**:
```bash
export NCCL_DEBUG=INFO
export NCCL_SOCKET_IFNAME=eth0 # Or your network interface
vllm serve MODEL --tensor-parallel-size 8
```
### Symptom: `NCCL error: unhandled cuda error`
**Solutions**:
```bash
# Set NCCL to use correct network interface
export NCCL_SOCKET_IFNAME=eth0 # Replace with your interface
# Increase timeout
export NCCL_TIMEOUT=1800 # 30 minutes
# Force P2P for debugging
export NCCL_P2P_DISABLE=1
```
## Debugging tools and commands
### Enable debug logging
```bash
export VLLM_LOGGING_LEVEL=DEBUG
vllm serve MODEL
```
### Monitor GPU usage
```bash
# Real-time GPU monitoring
watch -n 1 nvidia-smi
# Memory breakdown
nvidia-smi --query-gpu=memory.used,memory.free --format=csv -l 1
```
### Profile performance
```bash
# Built-in benchmarking
vllm bench throughput \
--model MODEL \
--input-tokens 128 \
--output-tokens 256 \
--num-prompts 100
vllm bench latency \
--model MODEL \
--input-tokens 128 \
--output-tokens 256 \
--batch-size 8
```
### Check metrics
```bash
# Prometheus metrics
curl http://localhost:9090/metrics
# Filter for specific metrics
curl http://localhost:9090/metrics | grep vllm_time_to_first_token
# Key metrics to monitor:
# - vllm_time_to_first_token_seconds
# - vllm_time_per_output_token_seconds
# - vllm_num_requests_running
# - vllm_gpu_cache_usage_perc
# - vllm_request_success_total
```
### Test server health
```bash
# Health check
curl http://localhost:8000/health
# Model info
curl http://localhost:8000/v1/models
# Test completion
curl http://localhost:8000/v1/completions \
-H "Content-Type: application/json" \
-d '{
"model": "MODEL",
"prompt": "Hello",
"max_tokens": 10
}'
```
### Common environment variables
```bash
# CUDA settings
export CUDA_VISIBLE_DEVICES=0,1,2,3 # Limit to specific GPUs
# vLLM settings
export VLLM_LOGGING_LEVEL=DEBUG
export VLLM_TRACE_FUNCTION=1 # Profile functions
export VLLM_USE_V1=1 # Use the V1 engine (faster)
# NCCL settings (distributed)
export NCCL_DEBUG=INFO
export NCCL_SOCKET_IFNAME=eth0
export NCCL_IB_DISABLE=0 # Enable InfiniBand
```
### Collect diagnostic info for bug reports
```bash
# System info
nvidia-smi
python --version
pip show vllm
# vLLM version and config
vllm --version
python -c "import vllm; print(vllm.__version__)"
# Run with debug logging
export VLLM_LOGGING_LEVEL=DEBUG
vllm serve MODEL 2>&1 | tee vllm_debug.log
# Include in bug report:
# - vllm_debug.log
# - nvidia-smi output
# - Full command used
# - Expected vs actual behavior
```

View file

@ -0,0 +1,3 @@
---
description: Specific model architectures and tools — computer vision (CLIP, SAM, Stable Diffusion), speech (Whisper), audio generation (AudioCraft), and multimodal models (LLaVA).
---

View file

@ -0,0 +1,567 @@
---
name: audiocraft-audio-generation
description: PyTorch library for audio generation including text-to-music (MusicGen) and text-to-sound (AudioGen). Use when you need to generate music from text descriptions, create sound effects, or perform melody-conditioned music generation.
version: 1.0.0
author: Orchestra Research
license: MIT
dependencies: [audiocraft, torch>=2.0.0, transformers>=4.30.0]
metadata:
hermes:
tags: [Multimodal, Audio Generation, Text-to-Music, Text-to-Audio, MusicGen]
---
# AudioCraft: Audio Generation
Comprehensive guide to using Meta's AudioCraft for text-to-music and text-to-audio generation with MusicGen, AudioGen, and EnCodec.
## When to use AudioCraft
**Use AudioCraft when:**
- Need to generate music from text descriptions
- Creating sound effects and environmental audio
- Building music generation applications
- Need melody-conditioned music generation
- Want stereo audio output
- Require controllable music generation with style transfer
**Key features:**
- **MusicGen**: Text-to-music generation with melody conditioning
- **AudioGen**: Text-to-sound effects generation
- **EnCodec**: High-fidelity neural audio codec
- **Multiple model sizes**: Small (300M) to Large (3.3B)
- **Stereo support**: Full stereo audio generation
- **Style conditioning**: MusicGen-Style for reference-based generation
**Use alternatives instead:**
- **Stable Audio**: For longer commercial music generation
- **Bark**: For text-to-speech with music/sound effects
- **Riffusion**: For spectrogram-based music generation
- **OpenAI Jukebox**: For raw audio generation with lyrics
## Quick start
### Installation
```bash
# From PyPI
pip install audiocraft
# From GitHub (latest)
pip install git+https://github.com/facebookresearch/audiocraft.git
# Or use HuggingFace Transformers
pip install transformers torch torchaudio
```
### Basic text-to-music (AudioCraft)
```python
import torchaudio
from audiocraft.models import MusicGen
# Load model
model = MusicGen.get_pretrained('facebook/musicgen-small')
# Set generation parameters
model.set_generation_params(
duration=8, # seconds
top_k=250,
temperature=1.0
)
# Generate from text
descriptions = ["happy upbeat electronic dance music with synths"]
wav = model.generate(descriptions)
# Save audio
torchaudio.save("output.wav", wav[0].cpu(), sample_rate=32000)
```
### Using HuggingFace Transformers
```python
from transformers import AutoProcessor, MusicgenForConditionalGeneration
import scipy
# Load model and processor
processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
model.to("cuda")
# Generate music
inputs = processor(
text=["80s pop track with bassy drums and synth"],
padding=True,
return_tensors="pt"
).to("cuda")
audio_values = model.generate(
**inputs,
do_sample=True,
guidance_scale=3,
max_new_tokens=256
)
# Save
sampling_rate = model.config.audio_encoder.sampling_rate
scipy.io.wavfile.write("output.wav", rate=sampling_rate, data=audio_values[0, 0].cpu().numpy())
```
### Text-to-sound with AudioGen
```python
from audiocraft.models import AudioGen
# Load AudioGen
model = AudioGen.get_pretrained('facebook/audiogen-medium')
model.set_generation_params(duration=5)
# Generate sound effects
descriptions = ["dog barking in a park with birds chirping"]
wav = model.generate(descriptions)
torchaudio.save("sound.wav", wav[0].cpu(), sample_rate=16000)
```
## Core concepts
### Architecture overview
```
AudioCraft Architecture:
┌──────────────────────────────────────────────────────────────┐
│ Text Encoder (T5) │
│ │ │
│ Text Embeddings │
└────────────────────────┬─────────────────────────────────────┘
┌────────────────────────▼─────────────────────────────────────┐
│ Transformer Decoder (LM) │
│ Auto-regressively generates audio tokens │
│ Using efficient token interleaving patterns │
└────────────────────────┬─────────────────────────────────────┘
┌────────────────────────▼─────────────────────────────────────┐
│ EnCodec Audio Decoder │
│ Converts tokens back to audio waveform │
└──────────────────────────────────────────────────────────────┘
```
### Model variants
| Model | Size | Description | Use Case |
|-------|------|-------------|----------|
| `musicgen-small` | 300M | Text-to-music | Quick generation |
| `musicgen-medium` | 1.5B | Text-to-music | Balanced |
| `musicgen-large` | 3.3B | Text-to-music | Best quality |
| `musicgen-melody` | 1.5B | Text + melody | Melody conditioning |
| `musicgen-melody-large` | 3.3B | Text + melody | Best melody |
| `musicgen-stereo-*` | Varies | Stereo output | Stereo generation |
| `musicgen-style` | 1.5B | Style transfer | Reference-based |
| `audiogen-medium` | 1.5B | Text-to-sound | Sound effects |
### Generation parameters
| Parameter | Default | Description |
|-----------|---------|-------------|
| `duration` | 8.0 | Length in seconds (1-120) |
| `top_k` | 250 | Top-k sampling |
| `top_p` | 0.0 | Nucleus sampling (0 = disabled) |
| `temperature` | 1.0 | Sampling temperature |
| `cfg_coef` | 3.0 | Classifier-free guidance |
## MusicGen usage
### Text-to-music generation
```python
from audiocraft.models import MusicGen
import torchaudio
model = MusicGen.get_pretrained('facebook/musicgen-medium')
# Configure generation
model.set_generation_params(
duration=30, # Up to 30 seconds
top_k=250, # Sampling diversity
top_p=0.0, # 0 = use top_k only
temperature=1.0, # Creativity (higher = more varied)
cfg_coef=3.0 # Text adherence (higher = stricter)
)
# Generate multiple samples
descriptions = [
"epic orchestral soundtrack with strings and brass",
"chill lo-fi hip hop beat with jazzy piano",
"energetic rock song with electric guitar"
]
# Generate (returns [batch, channels, samples])
wav = model.generate(descriptions)
# Save each
for i, audio in enumerate(wav):
torchaudio.save(f"music_{i}.wav", audio.cpu(), sample_rate=32000)
```
### Melody-conditioned generation
```python
from audiocraft.models import MusicGen
import torchaudio
# Load melody model
model = MusicGen.get_pretrained('facebook/musicgen-melody')
model.set_generation_params(duration=30)
# Load melody audio
melody, sr = torchaudio.load("melody.wav")
# Generate with melody conditioning
descriptions = ["acoustic guitar folk song"]
wav = model.generate_with_chroma(descriptions, melody, sr)
torchaudio.save("melody_conditioned.wav", wav[0].cpu(), sample_rate=32000)
```
### Stereo generation
```python
from audiocraft.models import MusicGen
# Load stereo model
model = MusicGen.get_pretrained('facebook/musicgen-stereo-medium')
model.set_generation_params(duration=15)
descriptions = ["ambient electronic music with wide stereo panning"]
wav = model.generate(descriptions)
# wav shape: [batch, 2, samples] for stereo
print(f"Stereo shape: {wav.shape}") # [1, 2, 480000]
torchaudio.save("stereo.wav", wav[0].cpu(), sample_rate=32000)
```
### Audio continuation
```python
from transformers import AutoProcessor, MusicgenForConditionalGeneration
processor = AutoProcessor.from_pretrained("facebook/musicgen-medium")
model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-medium")
# Load audio to continue
import torchaudio
audio, sr = torchaudio.load("intro.wav")
# Process with text and audio
inputs = processor(
audio=audio.squeeze().numpy(),
sampling_rate=sr,
text=["continue with a epic chorus"],
padding=True,
return_tensors="pt"
)
# Generate continuation
audio_values = model.generate(**inputs, do_sample=True, guidance_scale=3, max_new_tokens=512)
```
## MusicGen-Style usage
### Style-conditioned generation
```python
from audiocraft.models import MusicGen
# Load style model
model = MusicGen.get_pretrained('facebook/musicgen-style')
# Configure generation with style
model.set_generation_params(
duration=30,
cfg_coef=3.0,
cfg_coef_beta=5.0 # Style influence
)
# Configure style conditioner
model.set_style_conditioner_params(
eval_q=3, # RVQ quantizers (1-6)
excerpt_length=3.0 # Style excerpt length
)
# Load style reference
style_audio, sr = torchaudio.load("reference_style.wav")
# Generate with text + style
descriptions = ["upbeat dance track"]
wav = model.generate_with_style(descriptions, style_audio, sr)
```
### Style-only generation (no text)
```python
# Generate matching style without text prompt
model.set_generation_params(
duration=30,
cfg_coef=3.0,
cfg_coef_beta=None # Disable double CFG for style-only
)
wav = model.generate_with_style([None], style_audio, sr)
```
## AudioGen usage
### Sound effect generation
```python
from audiocraft.models import AudioGen
import torchaudio
model = AudioGen.get_pretrained('facebook/audiogen-medium')
model.set_generation_params(duration=10)
# Generate various sounds
descriptions = [
"thunderstorm with heavy rain and lightning",
"busy city traffic with car horns",
"ocean waves crashing on rocks",
"crackling campfire in forest"
]
wav = model.generate(descriptions)
for i, audio in enumerate(wav):
torchaudio.save(f"sound_{i}.wav", audio.cpu(), sample_rate=16000)
```
## EnCodec usage
### Audio compression
```python
from audiocraft.models import CompressionModel
import torch
import torchaudio
# Load EnCodec
model = CompressionModel.get_pretrained('facebook/encodec_32khz')
# Load audio
wav, sr = torchaudio.load("audio.wav")
# Ensure correct sample rate
if sr != 32000:
resampler = torchaudio.transforms.Resample(sr, 32000)
wav = resampler(wav)
# Encode to tokens
with torch.no_grad():
encoded = model.encode(wav.unsqueeze(0))
codes = encoded[0] # Audio codes
# Decode back to audio
with torch.no_grad():
decoded = model.decode(codes)
torchaudio.save("reconstructed.wav", decoded[0].cpu(), sample_rate=32000)
```
## Common workflows
### Workflow 1: Music generation pipeline
```python
import torch
import torchaudio
from audiocraft.models import MusicGen
class MusicGenerator:
def __init__(self, model_name="facebook/musicgen-medium"):
self.model = MusicGen.get_pretrained(model_name)
self.sample_rate = 32000
def generate(self, prompt, duration=30, temperature=1.0, cfg=3.0):
self.model.set_generation_params(
duration=duration,
top_k=250,
temperature=temperature,
cfg_coef=cfg
)
with torch.no_grad():
wav = self.model.generate([prompt])
return wav[0].cpu()
def generate_batch(self, prompts, duration=30):
self.model.set_generation_params(duration=duration)
with torch.no_grad():
wav = self.model.generate(prompts)
return wav.cpu()
def save(self, audio, path):
torchaudio.save(path, audio, sample_rate=self.sample_rate)
# Usage
generator = MusicGenerator()
audio = generator.generate(
"epic cinematic orchestral music",
duration=30,
temperature=1.0
)
generator.save(audio, "epic_music.wav")
```
### Workflow 2: Sound design batch processing
```python
import json
from pathlib import Path
from audiocraft.models import AudioGen
import torchaudio
def batch_generate_sounds(sound_specs, output_dir):
"""
Generate multiple sounds from specifications.
Args:
sound_specs: list of {"name": str, "description": str, "duration": float}
output_dir: output directory path
"""
model = AudioGen.get_pretrained('facebook/audiogen-medium')
output_dir = Path(output_dir)
output_dir.mkdir(exist_ok=True)
results = []
for spec in sound_specs:
model.set_generation_params(duration=spec.get("duration", 5))
wav = model.generate([spec["description"]])
output_path = output_dir / f"{spec['name']}.wav"
torchaudio.save(str(output_path), wav[0].cpu(), sample_rate=16000)
results.append({
"name": spec["name"],
"path": str(output_path),
"description": spec["description"]
})
return results
# Usage
sounds = [
{"name": "explosion", "description": "massive explosion with debris", "duration": 3},
{"name": "footsteps", "description": "footsteps on wooden floor", "duration": 5},
{"name": "door", "description": "wooden door creaking and closing", "duration": 2}
]
results = batch_generate_sounds(sounds, "sound_effects/")
```
### Workflow 3: Gradio demo
```python
import gradio as gr
import torch
import torchaudio
from audiocraft.models import MusicGen
model = MusicGen.get_pretrained('facebook/musicgen-small')
def generate_music(prompt, duration, temperature, cfg_coef):
model.set_generation_params(
duration=duration,
temperature=temperature,
cfg_coef=cfg_coef
)
with torch.no_grad():
wav = model.generate([prompt])
# Save to temp file
path = "temp_output.wav"
torchaudio.save(path, wav[0].cpu(), sample_rate=32000)
return path
demo = gr.Interface(
fn=generate_music,
inputs=[
gr.Textbox(label="Music Description", placeholder="upbeat electronic dance music"),
gr.Slider(1, 30, value=8, label="Duration (seconds)"),
gr.Slider(0.5, 2.0, value=1.0, label="Temperature"),
gr.Slider(1.0, 10.0, value=3.0, label="CFG Coefficient")
],
outputs=gr.Audio(label="Generated Music"),
title="MusicGen Demo"
)
demo.launch()
```
## Performance optimization
### Memory optimization
```python
# Use smaller model
model = MusicGen.get_pretrained('facebook/musicgen-small')
# Clear cache between generations
torch.cuda.empty_cache()
# Generate shorter durations
model.set_generation_params(duration=10) # Instead of 30
# Use half precision
model = model.half()
```
### Batch processing efficiency
```python
# Process multiple prompts at once (more efficient)
descriptions = ["prompt1", "prompt2", "prompt3", "prompt4"]
wav = model.generate(descriptions) # Single batch
# Instead of
for desc in descriptions:
wav = model.generate([desc]) # Multiple batches (slower)
```
### GPU memory requirements
| Model | FP32 VRAM | FP16 VRAM |
|-------|-----------|-----------|
| musicgen-small | ~4GB | ~2GB |
| musicgen-medium | ~8GB | ~4GB |
| musicgen-large | ~16GB | ~8GB |
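To pick a variant programmatically, here is a rough sketch based on the table above; the thresholds approximate the FP32 column and are not official requirements:
```python
import torch
from audiocraft.models import MusicGen

def pick_musicgen(device_index=0):
    """Choose a MusicGen size from free VRAM, using the rough table above."""
    if not torch.cuda.is_available():
        return MusicGen.get_pretrained("facebook/musicgen-small", device="cpu")
    free_bytes, _total = torch.cuda.mem_get_info(device_index)
    free_gb = free_bytes / 1e9
    if free_gb >= 16:
        name = "facebook/musicgen-large"
    elif free_gb >= 8:
        name = "facebook/musicgen-medium"
    else:
        name = "facebook/musicgen-small"
    return MusicGen.get_pretrained(name)

model = pick_musicgen()
```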
## Common issues
| Issue | Solution |
|-------|----------|
| CUDA OOM | Use smaller model, reduce duration |
| Poor quality | Increase cfg_coef, better prompts |
| Generation too short | Check max duration setting |
| Audio artifacts | Try different temperature |
| Stereo not working | Use stereo model variant |
## References
- **[Advanced Usage](references/advanced-usage.md)** - Training, fine-tuning, deployment
- **[Troubleshooting](references/troubleshooting.md)** - Common issues and solutions
## Resources
- **GitHub**: https://github.com/facebookresearch/audiocraft
- **Paper (MusicGen)**: https://arxiv.org/abs/2306.05284
- **Paper (AudioGen)**: https://arxiv.org/abs/2209.15352
- **HuggingFace**: https://huggingface.co/facebook/musicgen-small
- **Demo**: https://huggingface.co/spaces/facebook/MusicGen

View file

@ -0,0 +1,666 @@
# AudioCraft Advanced Usage Guide
## Fine-tuning MusicGen
### Custom dataset preparation
```python
import os
import json
from pathlib import Path
import torchaudio
def prepare_dataset(audio_dir, output_dir, metadata_file):
"""
Prepare dataset for MusicGen fine-tuning.
Directory structure:
output_dir/
├── audio/
│ ├── 0001.wav
│ ├── 0002.wav
│ └── ...
└── metadata.json
"""
output_dir = Path(output_dir)
audio_output = output_dir / "audio"
audio_output.mkdir(parents=True, exist_ok=True)
# Load metadata (format: {"path": "...", "description": "..."})
with open(metadata_file) as f:
metadata = json.load(f)
processed = []
for idx, item in enumerate(metadata):
audio_path = Path(audio_dir) / item["path"]
# Load and resample to 32kHz
wav, sr = torchaudio.load(str(audio_path))
if sr != 32000:
resampler = torchaudio.transforms.Resample(sr, 32000)
wav = resampler(wav)
# Convert to mono if stereo
if wav.shape[0] > 1:
wav = wav.mean(dim=0, keepdim=True)
# Save processed audio
output_path = audio_output / f"{idx:04d}.wav"
torchaudio.save(str(output_path), wav, sample_rate=32000)
processed.append({
"path": str(output_path.relative_to(output_dir)),
"description": item["description"],
"duration": wav.shape[1] / 32000
})
# Save processed metadata
with open(output_dir / "metadata.json", "w") as f:
json.dump(processed, f, indent=2)
print(f"Processed {len(processed)} samples")
return processed
```
### Fine-tuning with dora
```bash
# AudioCraft uses dora for experiment management
# Install dora
pip install dora-search
# Clone AudioCraft
git clone https://github.com/facebookresearch/audiocraft.git
cd audiocraft
# Create config for fine-tuning
cat > config/solver/musicgen/finetune.yaml << 'EOF'
defaults:
- musicgen/musicgen_base
- /model: lm/musicgen_lm
- /conditioner: cond_base
solver: musicgen
autocast: true
autocast_dtype: float16
optim:
epochs: 100
batch_size: 4
lr: 1e-4
ema: 0.999
optimizer: adamw
dataset:
batch_size: 4
num_workers: 4
train:
- dset: your_dataset
root: /path/to/dataset
valid:
- dset: your_dataset
root: /path/to/dataset
checkpoint:
save_every: 10
keep_every_states: null
EOF
# Run fine-tuning
dora run solver=musicgen/finetune
```
### LoRA fine-tuning
```python
from peft import LoraConfig, get_peft_model
from audiocraft.models import MusicGen
import torch
# Load base model
model = MusicGen.get_pretrained('facebook/musicgen-small')
# Get the language model component
lm = model.lm
# Configure LoRA
lora_config = LoraConfig(
r=8,
lora_alpha=16,
target_modules=["q_proj", "v_proj", "k_proj", "out_proj"],
lora_dropout=0.05,
bias="none"
)
# Apply LoRA
lm = get_peft_model(lm, lora_config)
lm.print_trainable_parameters()
```
## Multi-GPU Training
### DataParallel
```python
import torch
import torch.nn as nn
from audiocraft.models import MusicGen
model = MusicGen.get_pretrained('facebook/musicgen-small')
# Wrap LM with DataParallel
if torch.cuda.device_count() > 1:
model.lm = nn.DataParallel(model.lm)
model.to("cuda")
```
### DistributedDataParallel
```python
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
def setup(rank, world_size):
dist.init_process_group("nccl", rank=rank, world_size=world_size)
torch.cuda.set_device(rank)
def train(rank, world_size):
setup(rank, world_size)
model = MusicGen.get_pretrained('facebook/musicgen-small')
model.lm = model.lm.to(rank)
model.lm = DDP(model.lm, device_ids=[rank])
# Training loop
# ...
dist.destroy_process_group()
```
## Custom Conditioning
### Adding new conditioners
```python
from audiocraft.modules.conditioners import BaseConditioner
import torch
class CustomConditioner(BaseConditioner):
"""Custom conditioner for additional control signals."""
def __init__(self, dim, output_dim):
super().__init__(dim, output_dim)
self.embed = torch.nn.Linear(dim, output_dim)
def forward(self, x):
return self.embed(x)
def tokenize(self, x):
# Tokenize input for conditioning
return x
# Use with MusicGen
from audiocraft.models.builders import get_lm_model
# Modify model config to include custom conditioner
# This requires editing the model configuration
```
### Melody conditioning internals
```python
from audiocraft.models import MusicGen
from audiocraft.modules.codebooks_patterns import DelayedPatternProvider
import torch
model = MusicGen.get_pretrained('facebook/musicgen-melody')
# Access chroma extractor
chroma_extractor = model.lm.condition_provider.conditioners.get('chroma')
# Manual chroma extraction
def extract_chroma(audio, sr):
"""Extract chroma features from audio."""
import librosa
# Compute chroma
chroma = librosa.feature.chroma_cqt(y=audio.numpy(), sr=sr)
return torch.from_numpy(chroma).float()
# Use extracted chroma for conditioning
chroma = extract_chroma(melody_audio, sample_rate)
```
## EnCodec Deep Dive
### Custom compression settings
```python
from audiocraft.models import CompressionModel
import torch
# Load EnCodec
encodec = CompressionModel.get_pretrained('facebook/encodec_32khz')
# Access codec parameters
print(f"Sample rate: {encodec.sample_rate}")
print(f"Channels: {encodec.channels}")
print(f"Cardinality: {encodec.cardinality}") # Codebook size
print(f"Num codebooks: {encodec.num_codebooks}")
print(f"Frame rate: {encodec.frame_rate}")
# Encode with specific bandwidth
# Lower bandwidth = more compression, lower quality
encodec.set_target_bandwidth(6.0) # 6 kbps
audio = torch.randn(1, 1, 32000) # 1 second
encoded = encodec.encode(audio)
decoded = encodec.decode(encoded[0])
```
### Streaming encoding
```python
import torch
from audiocraft.models import CompressionModel
encodec = CompressionModel.get_pretrained('facebook/encodec_32khz')
def encode_streaming(audio_stream, chunk_size=32000):
"""Encode audio in streaming fashion."""
all_codes = []
for chunk in audio_stream:
# Ensure chunk is right shape
if chunk.dim() == 1:
chunk = chunk.unsqueeze(0).unsqueeze(0)
with torch.no_grad():
codes = encodec.encode(chunk)[0]
all_codes.append(codes)
return torch.cat(all_codes, dim=-1)
def decode_streaming(codes_stream, output_stream):
"""Decode codes in streaming fashion."""
for codes in codes_stream:
with torch.no_grad():
audio = encodec.decode(codes)
output_stream.write(audio.cpu().numpy())
```
## MultiBand Diffusion
### Using MBD for enhanced quality
```python
from audiocraft.models import MusicGen, MultiBandDiffusion
# Load MusicGen
model = MusicGen.get_pretrained('facebook/musicgen-medium')
# Load MultiBand Diffusion
mbd = MultiBandDiffusion.get_mbd_musicgen()
model.set_generation_params(duration=10)
# Generate with standard decoder
descriptions = ["epic orchestral music"]
wav_standard = model.generate(descriptions)
# Generate tokens and use MBD decoder
with torch.no_grad():
# Get tokens
gen_tokens = model.generate_tokens(descriptions)
# Decode with MBD
wav_mbd = mbd.tokens_to_wav(gen_tokens)
# Compare quality
print(f"Standard shape: {wav_standard.shape}")
print(f"MBD shape: {wav_mbd.shape}")
```
## API Server Deployment
### FastAPI server
```python
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import torch
import torchaudio
from audiocraft.models import MusicGen
import io
import base64
app = FastAPI()
# Load model at startup
model = None
@app.on_event("startup")
async def load_model():
global model
model = MusicGen.get_pretrained('facebook/musicgen-small')
model.set_generation_params(duration=10)
class GenerateRequest(BaseModel):
prompt: str
duration: float = 10.0
temperature: float = 1.0
cfg_coef: float = 3.0
class GenerateResponse(BaseModel):
audio_base64: str
sample_rate: int
duration: float
@app.post("/generate", response_model=GenerateResponse)
async def generate(request: GenerateRequest):
if model is None:
raise HTTPException(status_code=500, detail="Model not loaded")
try:
model.set_generation_params(
duration=min(request.duration, 30),
temperature=request.temperature,
cfg_coef=request.cfg_coef
)
with torch.no_grad():
wav = model.generate([request.prompt])
# Convert to bytes
buffer = io.BytesIO()
torchaudio.save(buffer, wav[0].cpu(), sample_rate=32000, format="wav")
buffer.seek(0)
audio_base64 = base64.b64encode(buffer.read()).decode()
return GenerateResponse(
audio_base64=audio_base64,
sample_rate=32000,
duration=wav.shape[-1] / 32000
)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.get("/health")
async def health():
return {"status": "ok", "model_loaded": model is not None}
# Run: uvicorn server:app --host 0.0.0.0 --port 8000
```
### Batch processing service
```python
import asyncio
from concurrent.futures import ThreadPoolExecutor
import torch
from audiocraft.models import MusicGen
class MusicGenService:
def __init__(self, model_name='facebook/musicgen-small', max_workers=2):
self.model = MusicGen.get_pretrained(model_name)
self.executor = ThreadPoolExecutor(max_workers=max_workers)
self.lock = asyncio.Lock()
async def generate_async(self, prompt, duration=10):
"""Async generation with thread pool."""
loop = asyncio.get_event_loop()
def _generate():
with torch.no_grad():
self.model.set_generation_params(duration=duration)
return self.model.generate([prompt])
# Run in thread pool
wav = await loop.run_in_executor(self.executor, _generate)
return wav[0].cpu()
async def generate_batch_async(self, prompts, duration=10):
"""Process multiple prompts concurrently."""
tasks = [self.generate_async(p, duration) for p in prompts]
return await asyncio.gather(*tasks)
# Usage
service = MusicGenService()
async def main():
prompts = ["jazz piano", "rock guitar", "electronic beats"]
results = await service.generate_batch_async(prompts)
return results
```
## Integration Patterns
### LangChain tool
```python
from langchain.tools import BaseTool
import torch
import torchaudio
from audiocraft.models import MusicGen
import tempfile
class MusicGeneratorTool(BaseTool):
name = "music_generator"
description = "Generate music from a text description. Input should be a detailed description of the music style, mood, and instruments."
def __init__(self):
super().__init__()
self.model = MusicGen.get_pretrained('facebook/musicgen-small')
self.model.set_generation_params(duration=15)
def _run(self, description: str) -> str:
with torch.no_grad():
wav = self.model.generate([description])
# Save to temp file
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
torchaudio.save(f.name, wav[0].cpu(), sample_rate=32000)
return f"Generated music saved to: {f.name}"
async def _arun(self, description: str) -> str:
return self._run(description)
```
### Gradio with advanced controls
```python
import gradio as gr
import torch
import torchaudio
from audiocraft.models import MusicGen
models = {}
def load_model(model_size):
if model_size not in models:
model_name = f"facebook/musicgen-{model_size}"
models[model_size] = MusicGen.get_pretrained(model_name)
return models[model_size]
def generate(prompt, duration, temperature, cfg_coef, top_k, model_size):
model = load_model(model_size)
model.set_generation_params(
duration=duration,
temperature=temperature,
cfg_coef=cfg_coef,
top_k=top_k
)
with torch.no_grad():
wav = model.generate([prompt])
# Save
path = "output.wav"
torchaudio.save(path, wav[0].cpu(), sample_rate=32000)
return path
demo = gr.Interface(
fn=generate,
inputs=[
gr.Textbox(label="Prompt", lines=3),
gr.Slider(1, 30, value=10, label="Duration (s)"),
gr.Slider(0.1, 2.0, value=1.0, label="Temperature"),
gr.Slider(0.5, 10.0, value=3.0, label="CFG Coefficient"),
gr.Slider(50, 500, value=250, step=50, label="Top-K"),
gr.Dropdown(["small", "medium", "large"], value="small", label="Model Size")
],
outputs=gr.Audio(label="Generated Music"),
title="MusicGen Advanced",
allow_flagging="never"
)
demo.launch(share=True)
```
## Audio Processing Pipeline
### Post-processing chain
```python
import torch
import torchaudio
import torchaudio.transforms as T
import numpy as np
class AudioPostProcessor:
def __init__(self, sample_rate=32000):
self.sample_rate = sample_rate
def normalize(self, audio, target_db=-14.0):
"""Normalize audio to target loudness."""
rms = torch.sqrt(torch.mean(audio ** 2))
target_rms = 10 ** (target_db / 20)
gain = target_rms / (rms + 1e-8)
return audio * gain
def fade_in_out(self, audio, fade_duration=0.1):
"""Apply fade in/out."""
fade_samples = int(fade_duration * self.sample_rate)
# Create fade curves
fade_in = torch.linspace(0, 1, fade_samples)
fade_out = torch.linspace(1, 0, fade_samples)
# Apply fades
audio[..., :fade_samples] *= fade_in
audio[..., -fade_samples:] *= fade_out
return audio
def apply_reverb(self, audio, decay=0.5):
"""Apply simple reverb effect."""
impulse = torch.zeros(int(self.sample_rate * 0.5))
impulse[0] = 1.0
impulse[int(self.sample_rate * 0.1)] = decay * 0.5
impulse[int(self.sample_rate * 0.2)] = decay * 0.25
# Convolve
audio = torch.nn.functional.conv1d(
audio.unsqueeze(0),
impulse.unsqueeze(0).unsqueeze(0),
padding=len(impulse) // 2
).squeeze(0)
return audio
def process(self, audio):
"""Full processing pipeline."""
audio = self.normalize(audio)
audio = self.fade_in_out(audio)
return audio
# Usage with MusicGen
from audiocraft.models import MusicGen
model = MusicGen.get_pretrained('facebook/musicgen-small')
model.set_generation_params(duration=10)
wav = model.generate(["chill ambient music"])
processor = AudioPostProcessor()
wav_processed = processor.process(wav[0].cpu())
torchaudio.save("processed.wav", wav_processed, sample_rate=32000)
```
## Evaluation
### Audio quality metrics
```python
import torch
from audiocraft.metrics import CLAPTextConsistencyMetric
from audiocraft.data.audio import audio_read
def evaluate_generation(audio_path, text_prompt):
"""Evaluate generated audio quality."""
# Load audio
wav, sr = audio_read(audio_path)
# CLAP consistency (text-audio alignment)
clap_metric = CLAPTextConsistencyMetric()
clap_score = clap_metric.compute(wav, [text_prompt])
return {
"clap_score": clap_score,
"duration": wav.shape[-1] / sr
}
# Batch evaluation
def evaluate_batch(generations):
"""Evaluate multiple generations."""
results = []
for gen in generations:
result = evaluate_generation(gen["path"], gen["prompt"])
result["prompt"] = gen["prompt"]
results.append(result)
# Aggregate
avg_clap = sum(r["clap_score"] for r in results) / len(results)
return {
"individual": results,
"average_clap": avg_clap
}
```
## Model Comparison
### MusicGen variants benchmark
| Model | CLAP Score | Generation Time (10s) | VRAM |
|-------|------------|----------------------|------|
| musicgen-small | 0.35 | ~5s | 2GB |
| musicgen-medium | 0.42 | ~15s | 4GB |
| musicgen-large | 0.48 | ~30s | 8GB |
| musicgen-melody | 0.45 | ~15s | 4GB |
| musicgen-stereo-medium | 0.41 | ~18s | 5GB |
### Prompt engineering tips
```python
# Good prompts - specific and descriptive
good_prompts = [
"upbeat electronic dance music with synthesizer leads and punchy drums at 128 bpm",
"melancholic piano ballad with strings, slow tempo, emotional and cinematic",
"funky disco groove with slap bass, brass section, and rhythmic guitar"
]
# Bad prompts - too vague
bad_prompts = [
"nice music",
"song",
"good beat"
]
# Structure: [mood] [genre] with [instruments] at [tempo/style]
```
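A tiny helper that assembles prompts following that structure (illustrative only; it is not part of the AudioCraft API):
```python
def build_prompt(mood, genre, instruments, tempo=None):
    """Assemble a prompt as: [mood] [genre] with [instruments] at [tempo/style]."""
    prompt = f"{mood} {genre} with {', '.join(instruments)}"
    if tempo:
        prompt += f" at {tempo}"
    return prompt

print(build_prompt("melancholic", "piano ballad", ["strings", "soft drums"], "a slow tempo"))
# -> "melancholic piano ballad with strings, soft drums at a slow tempo"
```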

View file

@ -0,0 +1,504 @@
# AudioCraft Troubleshooting Guide
## Installation Issues
### Import errors
**Error**: `ModuleNotFoundError: No module named 'audiocraft'`
**Solutions**:
```bash
# Install from PyPI
pip install audiocraft
# Or from GitHub
pip install git+https://github.com/facebookresearch/audiocraft.git
# Verify installation
python -c "from audiocraft.models import MusicGen; print('OK')"
```
### FFmpeg not found
**Error**: `RuntimeError: ffmpeg not found`
**Solutions**:
```bash
# Ubuntu/Debian
sudo apt-get install ffmpeg
# macOS
brew install ffmpeg
# Windows (using conda)
conda install -c conda-forge ffmpeg
# Verify
ffmpeg -version
```
### PyTorch CUDA mismatch
**Error**: `RuntimeError: CUDA error: no kernel image is available`
**Solutions**:
```bash
# Check CUDA version
nvcc --version
python -c "import torch; print(torch.version.cuda)"
# Install matching PyTorch
pip install torch torchaudio --index-url https://download.pytorch.org/whl/cu121
# For CUDA 11.8
pip install torch torchaudio --index-url https://download.pytorch.org/whl/cu118
```
### xformers issues
**Error**: `ImportError: xformers` related errors
**Solutions**:
```bash
# Install xformers for memory efficiency
pip install xformers
# Or disable xformers
export AUDIOCRAFT_USE_XFORMERS=0
# In Python
import os
os.environ["AUDIOCRAFT_USE_XFORMERS"] = "0"
from audiocraft.models import MusicGen
```
## Model Loading Issues
### Out of memory during load
**Error**: `torch.cuda.OutOfMemoryError` during model loading
**Solutions**:
```python
# Use smaller model
model = MusicGen.get_pretrained('facebook/musicgen-small')
# Force CPU loading first
import torch
device = "cpu"
model = MusicGen.get_pretrained('facebook/musicgen-small', device=device)
model = model.to("cuda")
# Use HuggingFace with device_map
from transformers import MusicgenForConditionalGeneration
model = MusicgenForConditionalGeneration.from_pretrained(
"facebook/musicgen-small",
device_map="auto"
)
```
### Download failures
**Error**: Connection errors or incomplete downloads
**Solutions**:
```python
# Set cache directory
import os
os.environ["AUDIOCRAFT_CACHE_DIR"] = "/path/to/cache"
# Or for HuggingFace
os.environ["HF_HOME"] = "/path/to/hf_cache"
# Resume download
from huggingface_hub import snapshot_download
snapshot_download("facebook/musicgen-small", resume_download=True)
# Use local files
model = MusicGen.get_pretrained('/local/path/to/model')
```
### Wrong model type
**Error**: Loading wrong model for task
**Solutions**:
```python
# For text-to-music: use MusicGen
from audiocraft.models import MusicGen
model = MusicGen.get_pretrained('facebook/musicgen-medium')
# For text-to-sound: use AudioGen
from audiocraft.models import AudioGen
model = AudioGen.get_pretrained('facebook/audiogen-medium')
# For melody conditioning: use melody variant
model = MusicGen.get_pretrained('facebook/musicgen-melody')
# For stereo: use stereo variant
model = MusicGen.get_pretrained('facebook/musicgen-stereo-medium')
```
## Generation Issues
### Empty or silent output
**Problem**: Generated audio is silent or very quiet
**Solutions**:
```python
import torch
# Check output
wav = model.generate(["upbeat music"])
print(f"Shape: {wav.shape}")
print(f"Max amplitude: {wav.abs().max().item()}")
print(f"Mean amplitude: {wav.abs().mean().item()}")
# If too quiet, normalize
def normalize_audio(audio, target_db=-14.0):
rms = torch.sqrt(torch.mean(audio ** 2))
target_rms = 10 ** (target_db / 20)
gain = target_rms / (rms + 1e-8)
return audio * gain
wav_normalized = normalize_audio(wav)
```
### Poor quality output
**Problem**: Generated music sounds bad or noisy
**Solutions**:
```python
# Use larger model
model = MusicGen.get_pretrained('facebook/musicgen-large')
# Adjust generation parameters
model.set_generation_params(
duration=15,
top_k=250, # Increase for more diversity
temperature=0.8, # Lower for more focused output
cfg_coef=4.0 # Increase for better text adherence
)
# Use better prompts
# Bad: "music"
# Good: "upbeat electronic dance music with synthesizers and punchy drums"
# Try MultiBand Diffusion
from audiocraft.models import MultiBandDiffusion
mbd = MultiBandDiffusion.get_mbd_musicgen()
tokens = model.generate_tokens(["prompt"])
wav = mbd.tokens_to_wav(tokens)
```
### Generation too short
**Problem**: Audio shorter than expected
**Solutions**:
```python
# Check duration setting
model.set_generation_params(duration=30) # Set before generate
# Verify in generation
print(f"Duration setting: {model.generation_params}")
# Check output shape
wav = model.generate(["prompt"])
actual_duration = wav.shape[-1] / 32000
print(f"Actual duration: {actual_duration}s")
# Note: max duration is typically 30s
```
### Melody conditioning fails
**Error**: Issues with melody-conditioned generation
**Solutions**:
```python
import torchaudio
from audiocraft.models import MusicGen
# Load melody model (not base model)
model = MusicGen.get_pretrained('facebook/musicgen-melody')
# Load and prepare melody
melody, sr = torchaudio.load("melody.wav")
# Resample to model sample rate if needed
if sr != 32000:
resampler = torchaudio.transforms.Resample(sr, 32000)
melody = resampler(melody)
# Ensure correct shape [batch, channels, samples]
if melody.dim() == 1:
melody = melody.unsqueeze(0).unsqueeze(0)
elif melody.dim() == 2:
melody = melody.unsqueeze(0)
# Convert stereo to mono
if melody.shape[1] > 1:
melody = melody.mean(dim=1, keepdim=True)
# Generate with melody
model.set_generation_params(duration=min(melody.shape[-1] / 32000, 30))
wav = model.generate_with_chroma(["piano cover"], melody, 32000)
```
## Memory Issues
### CUDA out of memory
**Error**: `torch.cuda.OutOfMemoryError: CUDA out of memory`
**Solutions**:
```python
import torch
# Clear cache before generation
torch.cuda.empty_cache()
# Use smaller model
model = MusicGen.get_pretrained('facebook/musicgen-small')
# Reduce duration
model.set_generation_params(duration=10) # Instead of 30
# Generate one at a time
for prompt in prompts:
wav = model.generate([prompt])
save_audio(wav)
torch.cuda.empty_cache()
# Use CPU for very large generations
model = MusicGen.get_pretrained('facebook/musicgen-small', device="cpu")
```
### Memory leak during batch processing
**Problem**: Memory grows over time
**Solutions**:
```python
import gc
import torch
def generate_with_cleanup(model, prompts):
results = []
for prompt in prompts:
with torch.no_grad():
wav = model.generate([prompt])
results.append(wav.cpu())
# Cleanup
del wav
gc.collect()
torch.cuda.empty_cache()
return results
# Use context manager
with torch.inference_mode():
wav = model.generate(["prompt"])
```
## Audio Format Issues
### Wrong sample rate
**Problem**: Audio plays at wrong speed
**Solutions**:
```python
import torchaudio
# MusicGen outputs at 32kHz
sample_rate = 32000
# AudioGen outputs at 16kHz
sample_rate = 16000
# Always use correct rate when saving
torchaudio.save("output.wav", wav[0].cpu(), sample_rate=sample_rate)
# Resample if needed
resampler = torchaudio.transforms.Resample(32000, 44100)
wav_resampled = resampler(wav)
```
### Stereo/mono mismatch
**Problem**: Wrong number of channels
**Solutions**:
```python
# Check model type
print(f"Audio channels: {wav.shape}")
# Mono: [batch, 1, samples]
# Stereo: [batch, 2, samples]
# Convert mono to stereo
if wav.shape[1] == 1:
wav_stereo = wav.repeat(1, 2, 1)
# Convert stereo to mono
if wav.shape[1] == 2:
wav_mono = wav.mean(dim=1, keepdim=True)
# Use stereo model for stereo output
model = MusicGen.get_pretrained('facebook/musicgen-stereo-medium')
```
### Clipping and distortion
**Problem**: Audio has clipping or distortion
**Solutions**:
```python
import torch
# Check for clipping
max_val = wav.abs().max().item()
print(f"Max amplitude: {max_val}")
# Normalize to prevent clipping
if max_val > 1.0:
wav = wav / max_val
# Apply soft clipping
def soft_clip(x, threshold=0.9):
return torch.tanh(x / threshold) * threshold
wav_clipped = soft_clip(wav)
# Lower temperature during generation
model.set_generation_params(temperature=0.7) # More controlled
```
## HuggingFace Transformers Issues
### Processor errors
**Error**: Issues with MusicgenProcessor
**Solutions**:
```python
from transformers import AutoProcessor, MusicgenForConditionalGeneration
# Load matching processor and model
processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
# Ensure inputs are on same device
inputs = processor(
text=["prompt"],
padding=True,
return_tensors="pt"
).to("cuda")
# Check processor configuration
print(processor.tokenizer)
print(processor.feature_extractor)
```
### Generation parameter errors
**Error**: Invalid generation parameters
**Solutions**:
```python
# HuggingFace uses different parameter names
audio_values = model.generate(
**inputs,
do_sample=True, # Enable sampling
guidance_scale=3.0, # CFG (not cfg_coef)
max_new_tokens=256, # Token limit (not duration)
temperature=1.0
)
# Calculate tokens from duration
# ~50 tokens per second
duration_seconds = 10
max_tokens = duration_seconds * 50
audio_values = model.generate(**inputs, max_new_tokens=max_tokens)
```
## Performance Issues
### Slow generation
**Problem**: Generation takes too long
**Solutions**:
```python
# Use smaller model
model = MusicGen.get_pretrained('facebook/musicgen-small')
# Reduce duration
model.set_generation_params(duration=10)
# Use GPU
model.to("cuda")
# Enable flash attention if available
# (requires compatible hardware)
# Batch multiple prompts
prompts = ["prompt1", "prompt2", "prompt3"]
wav = model.generate(prompts) # Single batch is faster than loop
# Use compile (PyTorch 2.0+)
model.lm = torch.compile(model.lm)
```
### CPU fallback
**Problem**: Generation running on CPU instead of GPU
**Solutions**:
```python
import torch
# Check CUDA availability
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"CUDA device: {torch.cuda.get_device_name(0)}")
# Explicitly move to GPU
model = MusicGen.get_pretrained('facebook/musicgen-small')
model.to("cuda")
# Verify model device
print(f"Model device: {next(model.lm.parameters()).device}")
```
## Common Error Messages
| Error | Cause | Solution |
|-------|-------|----------|
| `CUDA out of memory` | Model too large | Use smaller model, reduce duration |
| `ffmpeg not found` | FFmpeg not installed | Install FFmpeg |
| `No module named 'audiocraft'` | Not installed | `pip install audiocraft` |
| `RuntimeError: Expected 3D tensor` | Wrong input shape | Check tensor dimensions |
| `KeyError: 'melody'` | Wrong model for melody | Use musicgen-melody |
| `Sample rate mismatch` | Wrong audio format | Resample to model rate |
## Getting Help
1. **GitHub Issues**: https://github.com/facebookresearch/audiocraft/issues
2. **HuggingFace Forums**: https://discuss.huggingface.co
3. **Paper**: https://arxiv.org/abs/2306.05284
### Reporting Issues
Include (the sketch after this list gathers most of these automatically):
- Python version
- PyTorch version
- CUDA version
- AudioCraft version: `pip show audiocraft`
- Full error traceback
- Minimal reproducible code
- Hardware (GPU model, VRAM)
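A small helper that collects most of this information (a sketch; adapt the fields to your environment):
```python
import platform
import subprocess
import torch

def collect_diagnostics():
    """Print the environment details requested above for a bug report."""
    print(f"Python: {platform.python_version()}")
    print(f"PyTorch: {torch.__version__}")
    print(f"CUDA (torch build): {torch.version.cuda}")
    if torch.cuda.is_available():
        print(f"GPU: {torch.cuda.get_device_name(0)}")
        props = torch.cuda.get_device_properties(0)
        print(f"VRAM: {props.total_memory / 1e9:.1f} GB")
    # AudioCraft version, captured from pip
    print(subprocess.run(["pip", "show", "audiocraft"],
                         capture_output=True, text=True).stdout)

collect_diagnostics()
```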

View file

@ -0,0 +1,256 @@
---
name: clip
description: OpenAI's model connecting vision and language. Enables zero-shot image classification, image-text matching, and cross-modal retrieval. Trained on 400M image-text pairs. Use for image search, content moderation, or vision-language tasks without fine-tuning. Best for general-purpose image understanding.
version: 1.0.0
author: Orchestra Research
license: MIT
dependencies: [transformers, torch, pillow]
metadata:
hermes:
tags: [Multimodal, CLIP, Vision-Language, Zero-Shot, Image Classification, OpenAI, Image Search, Cross-Modal Retrieval, Content Moderation]
---
# CLIP - Contrastive Language-Image Pre-Training
OpenAI's model that understands images from natural language.
## When to use CLIP
**Use when:**
- Zero-shot image classification (no training data needed)
- Image-text similarity/matching
- Semantic image search
- Content moderation (detect NSFW, violence)
- Visual question answering
- Cross-modal retrieval (image→text, text→image)
**Metrics**:
- **25,300+ GitHub stars**
- Trained on 400M image-text pairs
- Matches ResNet-50 on ImageNet (zero-shot)
- MIT License
**Use alternatives instead**:
- **BLIP-2**: Better captioning
- **LLaVA**: Vision-language chat
- **Segment Anything**: Image segmentation
## Quick start
### Installation
```bash
pip install git+https://github.com/openai/CLIP.git
pip install torch torchvision ftfy regex tqdm
```
### Zero-shot classification
```python
import torch
import clip
from PIL import Image
# Load model
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)
# Load image
image = preprocess(Image.open("photo.jpg")).unsqueeze(0).to(device)
# Define possible labels
text = clip.tokenize(["a dog", "a cat", "a bird", "a car"]).to(device)
# Compute similarity
with torch.no_grad():
image_features = model.encode_image(image)
text_features = model.encode_text(text)
# Cosine similarity
logits_per_image, logits_per_text = model(image, text)
probs = logits_per_image.softmax(dim=-1).cpu().numpy()
# Print results
labels = ["a dog", "a cat", "a bird", "a car"]
for label, prob in zip(labels, probs[0]):
print(f"{label}: {prob:.2%}")
```
## Available models
```python
# Models (sorted by size)
models = [
"RN50", # ResNet-50
"RN101", # ResNet-101
"ViT-B/32", # Vision Transformer (recommended)
"ViT-B/16", # Better quality, slower
"ViT-L/14", # Best quality, slowest
]
model, preprocess = clip.load("ViT-B/32")
```
| Model | Parameters | Speed | Quality |
|-------|------------|-------|---------|
| RN50 | 102M | Fast | Good |
| ViT-B/32 | 151M | Medium | Better |
| ViT-L/14 | 428M | Slow | Best |
## Image-text similarity
```python
# Compute embeddings
image_features = model.encode_image(image)
text_features = model.encode_text(text)
# Normalize
image_features /= image_features.norm(dim=-1, keepdim=True)
text_features /= text_features.norm(dim=-1, keepdim=True)
# Cosine similarity
similarity = (image_features @ text_features.T).item()
print(f"Similarity: {similarity:.4f}")
```
## Semantic image search
```python
# Index images
image_paths = ["img1.jpg", "img2.jpg", "img3.jpg"]
image_embeddings = []
for img_path in image_paths:
image = preprocess(Image.open(img_path)).unsqueeze(0).to(device)
with torch.no_grad():
embedding = model.encode_image(image)
embedding /= embedding.norm(dim=-1, keepdim=True)
image_embeddings.append(embedding)
image_embeddings = torch.cat(image_embeddings)
# Search with text query
query = "a sunset over the ocean"
text_input = clip.tokenize([query]).to(device)
with torch.no_grad():
text_embedding = model.encode_text(text_input)
text_embedding /= text_embedding.norm(dim=-1, keepdim=True)
# Find most similar images
similarities = (text_embedding @ image_embeddings.T).squeeze(0)
top_k = similarities.topk(3)
for idx, score in zip(top_k.indices, top_k.values):
print(f"{image_paths[idx]}: {score:.3f}")
```
## Content moderation
```python
# Define categories
categories = [
"safe for work",
"not safe for work",
"violent content",
"graphic content"
]
text = clip.tokenize(categories).to(device)
# Check image
with torch.no_grad():
logits_per_image, _ = model(image, text)
probs = logits_per_image.softmax(dim=-1)
# Get classification
max_idx = probs.argmax().item()
max_prob = probs[0, max_idx].item()
print(f"Category: {categories[max_idx]} ({max_prob:.2%})")
```
## Batch processing
```python
# Process multiple images
images = [preprocess(Image.open(f"img{i}.jpg")) for i in range(10)]
images = torch.stack(images).to(device)
with torch.no_grad():
image_features = model.encode_image(images)
image_features /= image_features.norm(dim=-1, keepdim=True)
# Batch text
texts = ["a dog", "a cat", "a bird"]
text_tokens = clip.tokenize(texts).to(device)
with torch.no_grad():
text_features = model.encode_text(text_tokens)
text_features /= text_features.norm(dim=-1, keepdim=True)
# Similarity matrix (10 images × 3 texts)
similarities = image_features @ text_features.T
print(similarities.shape) # (10, 3)
```
## Integration with vector databases
```python
# Store CLIP embeddings in Chroma/FAISS
import chromadb
client = chromadb.Client()
collection = client.create_collection("image_embeddings")
# Add image embeddings
for img_path, embedding in zip(image_paths, image_embeddings):
collection.add(
embeddings=[embedding.cpu().numpy().tolist()],
metadatas=[{"path": img_path}],
ids=[img_path]
)
# Query with text
query = "a sunset"
with torch.no_grad():
    text_embedding = model.encode_text(clip.tokenize([query]).to(device))
    text_embedding /= text_embedding.norm(dim=-1, keepdim=True)
results = collection.query(
    query_embeddings=text_embedding.cpu().numpy().tolist(),  # one 512-dim embedding
    n_results=5
)
```
## Best practices
1. **Use ViT-B/32 for most cases** - Good balance
2. **Normalize embeddings** - Required for cosine similarity
3. **Batch processing** - More efficient
4. **Cache embeddings** - Expensive to recompute (see the sketch below)
5. **Use descriptive labels** - Better zero-shot performance
6. **GPU recommended** - 10-50× faster
7. **Preprocess images** - Use provided preprocess function
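For point 4, a minimal caching sketch; the cache path and helper name are illustrative, and `model`/`preprocess` come from `clip.load` as above:
```python
import os
import torch
from PIL import Image

CACHE_PATH = "image_embeddings.pt"  # illustrative cache location

def get_image_embeddings(image_paths, model, preprocess, device="cuda"):
    """Compute CLIP image embeddings once, then reuse them from disk."""
    if os.path.exists(CACHE_PATH):
        return torch.load(CACHE_PATH)
    embeddings = []
    with torch.no_grad():
        for path in image_paths:
            image = preprocess(Image.open(path)).unsqueeze(0).to(device)
            feat = model.encode_image(image)
            feat /= feat.norm(dim=-1, keepdim=True)  # normalize for cosine similarity
            embeddings.append(feat.cpu())
    embeddings = torch.cat(embeddings)
    torch.save(embeddings, CACHE_PATH)
    return embeddings
```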
## Performance
| Operation | CPU | GPU (V100) |
|-----------|-----|------------|
| Image encoding | ~200ms | ~20ms |
| Text encoding | ~50ms | ~5ms |
| Similarity compute | <1ms | <1ms |
## Limitations
1. **Not for fine-grained tasks** - Best for broad categories
2. **Requires descriptive text** - Vague labels perform poorly
3. **Biased on web data** - May have dataset biases
4. **No bounding boxes** - Whole image only
5. **Limited spatial understanding** - Position/counting weak
## Resources
- **GitHub**: https://github.com/openai/CLIP ⭐ 25,300+
- **Paper**: https://arxiv.org/abs/2103.00020
- **Colab**: https://colab.research.google.com/github/openai/clip/
- **License**: MIT

View file

@ -0,0 +1,207 @@
# CLIP Applications Guide
Practical applications and use cases for CLIP.
## Zero-shot image classification
```python
import torch
import clip
from PIL import Image
model, preprocess = clip.load("ViT-B/32")
# Define categories
categories = [
"a photo of a dog",
"a photo of a cat",
"a photo of a bird",
"a photo of a car",
"a photo of a person"
]
# Prepare image
image = preprocess(Image.open("photo.jpg")).unsqueeze(0)
text = clip.tokenize(categories)
# Classify
with torch.no_grad():
image_features = model.encode_image(image)
text_features = model.encode_text(text)
logits_per_image, _ = model(image, text)
probs = logits_per_image.softmax(dim=-1).cpu().numpy()
# Print results
for category, prob in zip(categories, probs[0]):
print(f"{category}: {prob:.2%}")
```
## Semantic image search
```python
# Index images
image_database = []
image_paths = ["img1.jpg", "img2.jpg", "img3.jpg"]
for img_path in image_paths:
image = preprocess(Image.open(img_path)).unsqueeze(0)
with torch.no_grad():
features = model.encode_image(image)
features /= features.norm(dim=-1, keepdim=True)
image_database.append((img_path, features))
# Search with text
query = "a sunset over mountains"
text_input = clip.tokenize([query])
with torch.no_grad():
text_features = model.encode_text(text_input)
text_features /= text_features.norm(dim=-1, keepdim=True)
# Find matches
similarities = []
for img_path, img_features in image_database:
similarity = (text_features @ img_features.T).item()
similarities.append((img_path, similarity))
# Sort by similarity
similarities.sort(key=lambda x: x[1], reverse=True)
for img_path, score in similarities[:3]:
print(f"{img_path}: {score:.3f}")
```
## Content moderation
```python
# Define safety categories
categories = [
"safe for work content",
"not safe for work content",
"violent or graphic content",
"hate speech or offensive content",
"spam or misleading content"
]
text = clip.tokenize(categories)
# Check image
with torch.no_grad():
logits, _ = model(image, text)
probs = logits.softmax(dim=-1)
# Get classification
max_idx = probs.argmax().item()
confidence = probs[0, max_idx].item()
if confidence > 0.7:
print(f"Classified as: {categories[max_idx]} ({confidence:.2%})")
else:
print(f"Uncertain classification (confidence: {confidence:.2%})")
```
## Image-to-text retrieval
```python
# Text database
captions = [
"A beautiful sunset over the ocean",
"A cute dog playing in the park",
"A modern city skyline at night",
"A delicious pizza with toppings"
]
# Encode captions
caption_features = []
for caption in captions:
text = clip.tokenize([caption])
with torch.no_grad():
features = model.encode_text(text)
features /= features.norm(dim=-1, keepdim=True)
caption_features.append(features)
caption_features = torch.cat(caption_features)
# Find matching captions for image
with torch.no_grad():
image_features = model.encode_image(image)
image_features /= image_features.norm(dim=-1, keepdim=True)
similarities = (image_features @ caption_features.T).squeeze(0)
top_k = similarities.topk(3)
for idx, score in zip(top_k.indices, top_k.values):
print(f"{captions[idx]}: {score:.3f}")
```
## Visual question answering
```python
# Create yes/no questions
image = preprocess(Image.open("photo.jpg")).unsqueeze(0)
questions = [
"a photo showing people",
"a photo showing animals",
"a photo taken indoors",
"a photo taken outdoors",
"a photo taken during daytime",
"a photo taken at night"
]
text = clip.tokenize(questions)
with torch.no_grad():
logits, _ = model(image, text)
probs = logits.softmax(dim=-1)
# Answer questions
for question, prob in zip(questions, probs[0]):
answer = "Yes" if prob > 0.5 else "No"
print(f"{question}: {answer} ({prob:.2%})")
```
## Image deduplication
```python
# Detect duplicate/similar images
def compute_similarity(img1_path, img2_path):
img1 = preprocess(Image.open(img1_path)).unsqueeze(0)
img2 = preprocess(Image.open(img2_path)).unsqueeze(0)
with torch.no_grad():
feat1 = model.encode_image(img1)
feat2 = model.encode_image(img2)
feat1 /= feat1.norm(dim=-1, keepdim=True)
feat2 /= feat2.norm(dim=-1, keepdim=True)
similarity = (feat1 @ feat2.T).item()
return similarity
# Check for duplicates
threshold = 0.95
image_pairs = [("img1.jpg", "img2.jpg"), ("img1.jpg", "img3.jpg")]
for img1, img2 in image_pairs:
sim = compute_similarity(img1, img2)
if sim > threshold:
print(f"{img1} and {img2} are duplicates (similarity: {sim:.3f})")
```
## Best practices
1. **Use descriptive labels** - "a photo of X" works better than just "X"
2. **Normalize embeddings** - Always normalize for cosine similarity
3. **Batch processing** - Process multiple images/texts together
4. **Cache embeddings** - Expensive to recompute
5. **Set appropriate thresholds** - Test on validation data
6. **Use GPU** - 10-50× faster than CPU
7. **Consider model size** - ViT-B/32 good default, ViT-L/14 for best quality
## Resources
- **Paper**: https://arxiv.org/abs/2103.00020
- **GitHub**: https://github.com/openai/CLIP
- **Colab**: https://colab.research.google.com/github/openai/clip/

View file

@ -0,0 +1,307 @@
---
name: llava
description: Large Language and Vision Assistant. Enables visual instruction tuning and image-based conversations. Combines CLIP vision encoder with Vicuna/LLaMA language models. Supports multi-turn image chat, visual question answering, and instruction following. Use for vision-language chatbots or image understanding tasks. Best for conversational image analysis.
version: 1.0.0
author: Orchestra Research
license: MIT
dependencies: [transformers, torch, pillow]
metadata:
hermes:
tags: [LLaVA, Vision-Language, Multimodal, Visual Question Answering, Image Chat, CLIP, Vicuna, Conversational AI, Instruction Tuning, VQA]
---
# LLaVA - Large Language and Vision Assistant
Open-source vision-language model for conversational image understanding.
## When to use LLaVA
**Use when:**
- Building vision-language chatbots
- Visual question answering (VQA)
- Image description and captioning
- Multi-turn image conversations
- Visual instruction following
- Document understanding with images
**Metrics**:
- **23,000+ GitHub stars**
- GPT-4V level capabilities (targeted)
- Apache 2.0 License
- Multiple model sizes (7B-34B params)
**Use alternatives instead**:
- **GPT-4V**: Highest quality, API-based
- **CLIP**: Simple zero-shot classification
- **BLIP-2**: Better for captioning only
- **Flamingo**: Research, not open-source
## Quick start
### Installation
```bash
# Clone repository
git clone https://github.com/haotian-liu/LLaVA
cd LLaVA
# Install
pip install -e .
```
### Basic usage
```python
from llava.model.builder import load_pretrained_model
from llava.mm_utils import get_model_name_from_path, process_images, tokenizer_image_token
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
from llava.conversation import conv_templates
from PIL import Image
import torch
# Load model
model_path = "liuhaotian/llava-v1.5-7b"
tokenizer, model, image_processor, context_len = load_pretrained_model(
model_path=model_path,
model_base=None,
model_name=get_model_name_from_path(model_path)
)
# Load image
image = Image.open("image.jpg")
image_tensor = process_images([image], image_processor, model.config)
image_tensor = image_tensor.to(model.device, dtype=torch.float16)
# Create conversation
conv = conv_templates["llava_v1"].copy()
conv.append_message(conv.roles[0], DEFAULT_IMAGE_TOKEN + "\nWhat is in this image?")
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()
# Generate response
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
with torch.inference_mode():
output_ids = model.generate(
input_ids,
images=image_tensor,
do_sample=True,
temperature=0.2,
max_new_tokens=512
)
response = tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()
print(response)
```
## Available models
| Model | Parameters | VRAM | Quality |
|-------|------------|------|---------|
| LLaVA-v1.5-7B | 7B | ~14 GB | Good |
| LLaVA-v1.5-13B | 13B | ~28 GB | Better |
| LLaVA-v1.6-34B | 34B | ~70 GB | Best |
```python
# Load different models
model_7b = "liuhaotian/llava-v1.5-7b"
model_13b = "liuhaotian/llava-v1.5-13b"
model_34b = "liuhaotian/llava-v1.6-34b"
# 4-bit quantization for lower VRAM
load_4bit = True # Reduces VRAM by ~4×
```
## CLI usage
```bash
# Single image query
python -m llava.serve.cli \
--model-path liuhaotian/llava-v1.5-7b \
--image-file image.jpg \
--query "What is in this image?"
# Multi-turn conversation
python -m llava.serve.cli \
--model-path liuhaotian/llava-v1.5-7b \
--image-file image.jpg
# Then type questions interactively
```
## Web UI (Gradio)
```bash
# Launch Gradio interface
python -m llava.serve.gradio_web_server \
--model-path liuhaotian/llava-v1.5-7b \
--load-4bit # Optional: reduce VRAM
# Access at http://localhost:7860
```
## Multi-turn conversations
```python
# Initialize conversation
conv = conv_templates["llava_v1"].copy()
# Turn 1
conv.append_message(conv.roles[0], DEFAULT_IMAGE_TOKEN + "\nWhat is in this image?")
conv.append_message(conv.roles[1], None)
response1 = generate(conv, model, image) # "A dog playing in a park"
# Turn 2
conv.messages[-1][1] = response1 # Add previous response
conv.append_message(conv.roles[0], "What breed is the dog?")
conv.append_message(conv.roles[1], None)
response2 = generate(conv, model, image) # "Golden Retriever"
# Turn 3
conv.messages[-1][1] = response2
conv.append_message(conv.roles[0], "What time of day is it?")
conv.append_message(conv.roles[1], None)
response3 = generate(conv, model, image)
```
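The `generate()` call above, and the `ask()` calls in the tasks below, are shorthand for the tokenize → generate → decode steps from "Basic usage". A minimal sketch of such a helper is shown here, assuming the `tokenizer`, `model`, and `image_processor` loaded earlier; the helper name and signature are illustrative, not part of the LLaVA API. A multi-turn `generate(conv, model, image)` follows the same pattern but reuses the passed conversation instead of building a fresh one.
```python
def ask(model, image, question, temperature=0.2, max_new_tokens=512):
    """Single-turn question about a PIL image; illustrative helper, not part of LLaVA itself."""
    image_tensor = process_images([image], image_processor, model.config).to(
        model.device, dtype=torch.float16
    )
    conv = conv_templates["llava_v1"].copy()
    conv.append_message(conv.roles[0], DEFAULT_IMAGE_TOKEN + "\n" + question)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()
    input_ids = tokenizer_image_token(
        prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
    ).unsqueeze(0).to(model.device)
    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=image_tensor,
            do_sample=True,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
        )
    return tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()
```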
## Common tasks
### Image captioning
```python
question = "Describe this image in detail."
response = ask(model, image, question)
```
### Visual question answering
```python
question = "How many people are in the image?"
response = ask(model, image, question)
```
### Object detection (textual)
```python
question = "List all the objects you can see in this image."
response = ask(model, image, question)
```
### Scene understanding
```python
question = "What is happening in this scene?"
response = ask(model, image, question)
```
### Document understanding
```python
question = "What is the main topic of this document?"
response = ask(model, document_image, question)
```
## Training custom model
```bash
# Stage 1: Feature alignment (558K image-caption pairs)
bash scripts/v1_5/pretrain.sh
# Stage 2: Visual instruction tuning (150K instruction data)
bash scripts/v1_5/finetune.sh
```
## Quantization (reduce VRAM)
```python
# 4-bit quantization
tokenizer, model, image_processor, context_len = load_pretrained_model(
model_path="liuhaotian/llava-v1.5-13b",
model_base=None,
model_name=get_model_name_from_path("liuhaotian/llava-v1.5-13b"),
load_4bit=True # Reduces VRAM ~4×
)
# 8-bit quantization: pass load_8bit=True instead of load_4bit
load_8bit=True  # Reduces VRAM ~2×
```
## Best practices
1. **Start with 7B model** - Good quality, manageable VRAM
2. **Use 4-bit quantization** - Reduces VRAM significantly
3. **GPU required** - CPU inference extremely slow
4. **Clear prompts** - Specific questions get better answers
5. **Multi-turn conversations** - Maintain conversation context
6. **Temperature 0.2-0.7** - Balance creativity/consistency
7. **max_new_tokens 512-1024** - For detailed responses
8. **Batch processing** - Process multiple images sequentially
## Performance
| Model | VRAM (FP16) | VRAM (4-bit) | Speed (tokens/s) |
|-------|-------------|--------------|------------------|
| 7B | ~14 GB | ~4 GB | ~20 |
| 13B | ~28 GB | ~8 GB | ~12 |
| 34B | ~70 GB | ~18 GB | ~5 |
*On A100 GPU*
## Benchmarks
LLaVA achieves competitive scores on:
- **VQAv2**: 78.5%
- **GQA**: 62.0%
- **MM-Vet**: 35.4%
- **MMBench**: 64.3%
## Limitations
1. **Hallucinations** - May describe things not in image
2. **Spatial reasoning** - Struggles with precise locations
3. **Small text** - Difficulty reading fine print
4. **Object counting** - Imprecise for many objects
5. **VRAM requirements** - Need powerful GPU
6. **Inference speed** - Slower than CLIP
## Integration with frameworks
### LangChain
```python
from langchain.llms.base import LLM

class LLaVALLM(LLM):
    @property
    def _llm_type(self) -> str:
        return "llava"

    def _call(self, prompt, stop=None):
        # Run the tokenize → generate → decode steps from "Basic usage" here
        return ask(model, image, prompt)  # e.g. the ask() helper sketched earlier

llm = LLaVALLM()
```
### Gradio App
```python
import gradio as gr
def chat(image, text, history):
response = ask_llava(model, image, text)
return response
demo = gr.ChatInterface(
chat,
additional_inputs=[gr.Image(type="pil")],
title="LLaVA Chat"
)
demo.launch()
```
## Resources
- **GitHub**: https://github.com/haotian-liu/LLaVA ⭐ 23,000+
- **Paper**: https://arxiv.org/abs/2304.08485
- **Demo**: https://llava.hliu.cc
- **Models**: https://huggingface.co/liuhaotian
- **License**: Apache 2.0

View file

@ -0,0 +1,197 @@
# LLaVA Training Guide
Guide to training and fine-tuning LLaVA models.
## Training stages
### Stage 1: Feature alignment (Pretraining)
**Purpose**: Align vision encoder with language model
**Data**: 558K image-caption pairs (CC3M subset)
```bash
# Download pretrained projector or train from scratch
bash scripts/v1_5/pretrain.sh
```
**Configuration:**
- Base model: Vicuna-7B or LLaMA-2-7B
- Vision encoder: CLIP ViT-L/14
- Training time: ~20 hours on 8× A100
### Stage 2: Visual instruction tuning
**Purpose**: Teach model to follow visual instructions
**Data**: 150K GPT-generated multimodal instruction data
```bash
# Fine-tune with instruction data
bash scripts/v1_5/finetune.sh
```
**Configuration:**
- Epochs: 1
- Batch size: 128 (across 8 GPUs)
- Learning rate: 2e-5
- Training time: ~24 hours on 8× A100
## Data format
### Instruction data format
```json
[
{
"id": "001",
"image": "path/to/image.jpg",
"conversations": [
{
"from": "human",
"value": "<image>\nWhat is in this image?"
},
{
"from": "gpt",
"value": "The image shows a dog playing in a park."
},
{
"from": "human",
"value": "What breed is the dog?"
},
{
"from": "gpt",
"value": "It appears to be a Golden Retriever."
}
]
}
]
```
## Fine-tuning on custom data
### Prepare your data
```python
import json
# Create instruction data
data = []
for image_path, qa_pairs in your_dataset:
conversations = []
for q, a in qa_pairs:
conversations.append({"from": "human", "value": f"<image>\n{q}"})
conversations.append({"from": "gpt", "value": a})
data.append({
"id": str(len(data)),
"image": image_path,
"conversations": conversations
})
# Save
with open("custom_data.json", "w") as f:
json.dump(data, f, indent=2)
```
### Fine-tune script
```bash
#!/bin/bash
# Set paths
DATA_PATH="custom_data.json"
IMAGE_FOLDER="path/to/images"
MODEL_PATH="liuhaotian/llava-v1.5-7b"
OUTPUT_DIR="./checkpoints/llava-custom"
# Fine-tune
deepspeed llava/train/train_mem.py \
--deepspeed ./scripts/zero2.json \
--model_name_or_path $MODEL_PATH \
--version v1 \
--data_path $DATA_PATH \
--image_folder $IMAGE_FOLDER \
--vision_tower openai/clip-vit-large-patch14-336 \
--mm_projector_type mlp2x_gelu \
--mm_vision_select_layer -2 \
--mm_use_im_start_end False \
--mm_use_im_patch_token False \
--image_aspect_ratio pad \
--group_by_modality_length True \
--bf16 True \
--output_dir $OUTPUT_DIR \
--num_train_epochs 1 \
--per_device_train_batch_size 16 \
--per_device_eval_batch_size 4 \
--gradient_accumulation_steps 1 \
--evaluation_strategy "no" \
--save_strategy "steps" \
--save_steps 50000 \
--save_total_limit 1 \
--learning_rate 2e-5 \
--weight_decay 0. \
--warmup_ratio 0.03 \
--lr_scheduler_type "cosine" \
--logging_steps 1 \
--tf32 True \
--model_max_length 2048 \
--gradient_checkpointing True \
--dataloader_num_workers 4 \
--lazy_preprocess True \
--report_to wandb
```
## LoRA fine-tuning (memory efficient)
```python
from peft import LoraConfig, get_peft_model
# LoRA config
lora_config = LoraConfig(
r=8, # LoRA rank
lora_alpha=16,
target_modules=["q_proj", "v_proj"],
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM"
)
# Apply LoRA
model = get_peft_model(base_model, lora_config)
# Train with much lower memory
```
## Hardware requirements
### Full fine-tuning
- **7B model**: 8× A100 (40GB)
- **13B model**: 8× A100 (80GB)
- **Training time**: 20-48 hours
### LoRA fine-tuning
- **7B model**: 1× A100 (40GB)
- **13B model**: 2× A100 (40GB)
- **Training time**: 10-24 hours
## Best practices
1. **Start with pretrained** - Don't train from scratch
2. **Use LoRA for efficiency** - 10× less memory
3. **Quality over quantity** - 1K high-quality > 10K low-quality
4. **Multi-turn conversations** - More engaging than single Q&A
5. **Diverse images** - Cover different scenarios
6. **Clear instructions** - Specific questions get better answers
7. **Monitor loss** - Should decrease smoothly
8. **Save checkpoints** - Training can fail
9. **Test regularly** - Validate on held-out set
10. **Use DeepSpeed** - For multi-GPU training
## Resources
- **Training script**: https://github.com/haotian-liu/LLaVA/tree/main/scripts
- **Data format**: https://github.com/haotian-liu/LLaVA/blob/main/docs/Data.md
- **Paper**: https://arxiv.org/abs/2304.08485

View file

@ -0,0 +1,503 @@
---
name: segment-anything-model
description: Foundation model for image segmentation with zero-shot transfer. Use when you need to segment any object in images using points, boxes, or masks as prompts, or automatically generate all object masks in an image.
version: 1.0.0
author: Orchestra Research
license: MIT
dependencies: [segment-anything, transformers>=4.30.0, torch>=1.7.0]
metadata:
hermes:
tags: [Multimodal, Image Segmentation, Computer Vision, SAM, Zero-Shot]
---
# Segment Anything Model (SAM)
Comprehensive guide to using Meta AI's Segment Anything Model for zero-shot image segmentation.
## When to use SAM
**Use SAM when:**
- Need to segment any object in images without task-specific training
- Building interactive annotation tools with point/box prompts
- Generating training data for other vision models
- Need zero-shot transfer to new image domains
- Building object detection/segmentation pipelines
- Processing medical, satellite, or domain-specific images
**Key features:**
- **Zero-shot segmentation**: Works on any image domain without fine-tuning
- **Flexible prompts**: Points, bounding boxes, or previous masks
- **Automatic segmentation**: Generate all object masks automatically
- **High quality**: Trained on 1.1 billion masks from 11 million images
- **Multiple model sizes**: ViT-B (fastest), ViT-L, ViT-H (most accurate)
- **ONNX export**: Deploy in browsers and edge devices
**Use alternatives instead:**
- **YOLO/Detectron2**: For real-time object detection with classes
- **Mask2Former**: For semantic/panoptic segmentation with categories
- **GroundingDINO + SAM**: For text-prompted segmentation
- **SAM 2**: For video segmentation tasks
## Quick start
### Installation
```bash
# From GitHub
pip install git+https://github.com/facebookresearch/segment-anything.git
# Optional dependencies
pip install opencv-python pycocotools matplotlib
# Or use HuggingFace transformers
pip install transformers
```
### Download checkpoints
```bash
# ViT-H (largest, most accurate) - 2.4GB
wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
# ViT-L (medium) - 1.2GB
wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth
# ViT-B (smallest, fastest) - 375MB
wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth
```
### Basic usage with SamPredictor
```python
import cv2
import numpy as np
from segment_anything import sam_model_registry, SamPredictor
# Load model
sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
sam.to(device="cuda")
# Create predictor
predictor = SamPredictor(sam)
# Set image (computes embeddings once)
image = cv2.imread("image.jpg")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
predictor.set_image(image)
# Predict with point prompts
input_point = np.array([[500, 375]]) # (x, y) coordinates
input_label = np.array([1]) # 1 = foreground, 0 = background
masks, scores, logits = predictor.predict(
point_coords=input_point,
point_labels=input_label,
multimask_output=True # Returns 3 mask options
)
# Select best mask
best_mask = masks[np.argmax(scores)]
```
### HuggingFace Transformers
```python
import torch
from PIL import Image
from transformers import SamModel, SamProcessor
# Load model and processor
model = SamModel.from_pretrained("facebook/sam-vit-huge")
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
model.to("cuda")
# Process image with point prompt
image = Image.open("image.jpg")
input_points = [[[450, 600]]] # Batch of points
inputs = processor(image, input_points=input_points, return_tensors="pt")
inputs = {k: v.to("cuda") for k, v in inputs.items()}
# Generate masks
with torch.no_grad():
outputs = model(**inputs)
# Post-process masks to original size
masks = processor.image_processor.post_process_masks(
outputs.pred_masks.cpu(),
inputs["original_sizes"].cpu(),
inputs["reshaped_input_sizes"].cpu()
)
```
## Core concepts
### Model architecture
```
SAM Architecture:
┌─────────────────┐ ┌─────────────────┐ ┌─────────────────┐
│ Image Encoder │────▶│ Prompt Encoder │────▶│ Mask Decoder │
│ (ViT) │ │ (Points/Boxes) │ │ (Transformer) │
└─────────────────┘ └─────────────────┘ └─────────────────┘
│ │ │
Image Embeddings Prompt Embeddings Masks + IoU
(computed once) (per prompt) predictions
```
### Model variants
| Model | Checkpoint | Size | Speed | Accuracy |
|-------|------------|------|-------|----------|
| ViT-H | `vit_h` | 2.4 GB | Slowest | Best |
| ViT-L | `vit_l` | 1.2 GB | Medium | Good |
| ViT-B | `vit_b` | 375 MB | Fastest | Good |
### Prompt types
| Prompt | Description | Use Case |
|--------|-------------|----------|
| Point (foreground) | Click on object | Single object selection |
| Point (background) | Click outside object | Exclude regions |
| Bounding box | Rectangle around object | Larger objects |
| Previous mask | Low-res mask input | Iterative refinement |
## Interactive segmentation
### Point prompts
```python
# Single foreground point
input_point = np.array([[500, 375]])
input_label = np.array([1])
masks, scores, logits = predictor.predict(
point_coords=input_point,
point_labels=input_label,
multimask_output=True
)
# Multiple points (foreground + background)
input_points = np.array([[500, 375], [600, 400], [450, 300]])
input_labels = np.array([1, 1, 0]) # 2 foreground, 1 background
masks, scores, logits = predictor.predict(
point_coords=input_points,
point_labels=input_labels,
multimask_output=False # Single mask when prompts are clear
)
```
### Box prompts
```python
# Bounding box [x1, y1, x2, y2]
input_box = np.array([425, 600, 700, 875])
masks, scores, logits = predictor.predict(
box=input_box,
multimask_output=False
)
```
### Combined prompts
```python
# Box + points for precise control
masks, scores, logits = predictor.predict(
point_coords=np.array([[500, 375]]),
point_labels=np.array([1]),
box=np.array([400, 300, 700, 600]),
multimask_output=False
)
```
### Iterative refinement
```python
# Initial prediction
masks, scores, logits = predictor.predict(
point_coords=np.array([[500, 375]]),
point_labels=np.array([1]),
multimask_output=True
)
# Refine with additional point using previous mask
masks, scores, logits = predictor.predict(
point_coords=np.array([[500, 375], [550, 400]]),
point_labels=np.array([1, 0]), # Add background point
mask_input=logits[np.argmax(scores)][None, :, :], # Use best mask
multimask_output=False
)
```
## Automatic mask generation
### Basic automatic segmentation
```python
from segment_anything import SamAutomaticMaskGenerator
# Create generator
mask_generator = SamAutomaticMaskGenerator(sam)
# Generate all masks
masks = mask_generator.generate(image)
# Each mask contains:
# - segmentation: binary mask
# - bbox: [x, y, w, h]
# - area: pixel count
# - predicted_iou: quality score
# - stability_score: robustness score
# - point_coords: generating point
```
### Customized generation
```python
mask_generator = SamAutomaticMaskGenerator(
model=sam,
points_per_side=32, # Grid density (more = more masks)
pred_iou_thresh=0.88, # Quality threshold
stability_score_thresh=0.95, # Stability threshold
crop_n_layers=1, # Multi-scale crops
crop_n_points_downscale_factor=2,
min_mask_region_area=100, # Remove tiny masks
)
masks = mask_generator.generate(image)
```
### Filtering masks
```python
# Sort by area (largest first)
masks = sorted(masks, key=lambda x: x['area'], reverse=True)
# Filter by predicted IoU
high_quality = [m for m in masks if m['predicted_iou'] > 0.9]
# Filter by stability score
stable_masks = [m for m in masks if m['stability_score'] > 0.95]
```
## Batched inference
### Multiple images
```python
# Process multiple images efficiently
images = [cv2.imread(f"image_{i}.jpg") for i in range(10)]
all_masks = []
for image in images:
predictor.set_image(image)
masks, _, _ = predictor.predict(
point_coords=np.array([[500, 375]]),
point_labels=np.array([1]),
multimask_output=True
)
all_masks.append(masks)
```
### Multiple prompts per image
```python
# Process multiple prompts efficiently (one image encoding)
predictor.set_image(image)
# Batch of point prompts
points = [
np.array([[100, 100]]),
np.array([[200, 200]]),
np.array([[300, 300]])
]
all_masks = []
for point in points:
masks, scores, _ = predictor.predict(
point_coords=point,
point_labels=np.array([1]),
multimask_output=True
)
all_masks.append(masks[np.argmax(scores)])
```
## ONNX deployment
### Export model
```bash
python scripts/export_onnx_model.py \
--checkpoint sam_vit_h_4b8939.pth \
--model-type vit_h \
--output sam_onnx.onnx \
--return-single-mask
```
### Use ONNX model
```python
import onnxruntime
# Load ONNX model
ort_session = onnxruntime.InferenceSession("sam_onnx.onnx")
# Run inference (image embeddings computed separately)
masks = ort_session.run(
None,
{
"image_embeddings": image_embeddings,
"point_coords": point_coords,
"point_labels": point_labels,
"mask_input": np.zeros((1, 1, 256, 256), dtype=np.float32),
"has_mask_input": np.array([0], dtype=np.float32),
"orig_im_size": np.array([h, w], dtype=np.float32)
}
)
```
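The ONNX session still needs image embeddings produced by the PyTorch image encoder, and prompt coordinates must be mapped to the model's resized input frame. A minimal sketch using the predictor API (`get_image_embedding`, `transform.apply_coords`); the example point is assumed:
```python
# Compute embeddings once with the PyTorch encoder, then reuse them for many ONNX prompts
import numpy as np
from segment_anything import SamPredictor

predictor = SamPredictor(sam)
predictor.set_image(image)  # RGB image, as in the examples above
image_embeddings = predictor.get_image_embedding().cpu().numpy()

# Map point prompts to the resized input frame expected by the exported model
point_coords = predictor.transform.apply_coords(
    np.array([[500, 375]], dtype=np.float32), image.shape[:2]
)[None, :, :].astype(np.float32)
point_labels = np.array([[1]], dtype=np.float32)
```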
## Common workflows
### Workflow 1: Annotation tool
```python
import cv2
# Load model
predictor = SamPredictor(sam)
predictor.set_image(image)
def on_click(event, x, y, flags, param):
if event == cv2.EVENT_LBUTTONDOWN:
# Foreground point
masks, scores, _ = predictor.predict(
point_coords=np.array([[x, y]]),
point_labels=np.array([1]),
multimask_output=True
)
# Display best mask (display_mask is a placeholder for your UI's overlay/drawing code)
display_mask(masks[np.argmax(scores)])
```
### Workflow 2: Object extraction
```python
def extract_object(image, point):
"""Extract object at point with transparent background."""
predictor.set_image(image)
masks, scores, _ = predictor.predict(
point_coords=np.array([point]),
point_labels=np.array([1]),
multimask_output=True
)
best_mask = masks[np.argmax(scores)]
# Create RGBA output
rgba = np.zeros((image.shape[0], image.shape[1], 4), dtype=np.uint8)
rgba[:, :, :3] = image
rgba[:, :, 3] = best_mask * 255
return rgba
```
### Workflow 3: Medical image segmentation
```python
# Process medical images (grayscale to RGB)
medical_image = cv2.imread("scan.png", cv2.IMREAD_GRAYSCALE)
rgb_image = cv2.cvtColor(medical_image, cv2.COLOR_GRAY2RGB)
predictor.set_image(rgb_image)
# Segment region of interest
masks, scores, _ = predictor.predict(
box=np.array([x1, y1, x2, y2]), # ROI bounding box
multimask_output=True
)
```
## Output format
### Mask data structure
```python
# SamAutomaticMaskGenerator output
{
"segmentation": np.ndarray, # H×W binary mask
"bbox": [x, y, w, h], # Bounding box
"area": int, # Pixel count
"predicted_iou": float, # 0-1 quality score
"stability_score": float, # 0-1 robustness score
"crop_box": [x, y, w, h], # Generation crop region
"point_coords": [[x, y]], # Input point
}
```
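To sanity-check automatic generation output, a small overlay helper is usually enough. A minimal matplotlib sketch, assuming `masks` comes from `SamAutomaticMaskGenerator.generate(image)` on an RGB `image`:
```python
import numpy as np
import matplotlib.pyplot as plt

def show_masks(image, masks, alpha=0.5):
    """Overlay each generated mask in a random color on top of the image."""
    overlay = image.astype(float).copy()
    for m in sorted(masks, key=lambda x: x["area"], reverse=True):
        color = np.random.rand(3) * 255
        seg = m["segmentation"]
        overlay[seg] = (1 - alpha) * overlay[seg] + alpha * color
    plt.figure(figsize=(10, 10))
    plt.imshow(overlay.astype(np.uint8))
    plt.axis("off")
    plt.show()

show_masks(image, masks)
```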
### COCO RLE format
```python
from pycocotools import mask as mask_utils
# Encode mask to RLE
rle = mask_utils.encode(np.asfortranarray(mask.astype(np.uint8)))
rle["counts"] = rle["counts"].decode("utf-8")
# Decode RLE to mask
decoded_mask = mask_utils.decode(rle)
```
## Performance optimization
### GPU memory
```python
# Use smaller model for limited VRAM
sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")
# Process images in batches
# Clear CUDA cache between large batches
torch.cuda.empty_cache()
```
### Speed optimization
```python
# Use half precision
sam = sam.half()
# Reduce points for automatic generation
mask_generator = SamAutomaticMaskGenerator(
model=sam,
points_per_side=16, # Default is 32
)
# Use ONNX for deployment
# Export with --return-single-mask for faster inference
```
## Common issues
| Issue | Solution |
|-------|----------|
| Out of memory | Use ViT-B model, reduce image size |
| Slow inference | Use ViT-B, reduce points_per_side |
| Poor mask quality | Try different prompts, use box + points |
| Edge artifacts | Use stability_score filtering |
| Small objects missed | Increase points_per_side |
## References
- **[Advanced Usage](references/advanced-usage.md)** - Batching, fine-tuning, integration
- **[Troubleshooting](references/troubleshooting.md)** - Common issues and solutions
## Resources
- **GitHub**: https://github.com/facebookresearch/segment-anything
- **Paper**: https://arxiv.org/abs/2304.02643
- **Demo**: https://segment-anything.com
- **SAM 2 (Video)**: https://github.com/facebookresearch/segment-anything-2
- **HuggingFace**: https://huggingface.co/facebook/sam-vit-huge

View file

@ -0,0 +1,589 @@
# Segment Anything Advanced Usage Guide
## SAM 2 (Video Segmentation)
### Overview
SAM 2 extends SAM to video segmentation with streaming memory architecture:
```bash
pip install git+https://github.com/facebookresearch/segment-anything-2.git
```
### Video segmentation
```python
from sam2.build_sam import build_sam2_video_predictor
predictor = build_sam2_video_predictor("sam2_hiera_l.yaml", "sam2_hiera_large.pt")
# Initialize with video
predictor.init_state(video_path="video.mp4")
# Add prompt on first frame
predictor.add_new_points(
frame_idx=0,
obj_id=1,
points=[[100, 200]],
labels=[1]
)
# Propagate through video
for frame_idx, masks in predictor.propagate_in_video():
# masks contains segmentation for all tracked objects
process_frame(frame_idx, masks)
```
### SAM 2 vs SAM comparison
| Feature | SAM | SAM 2 |
|---------|-----|-------|
| Input | Images only | Images + Videos |
| Architecture | ViT + Decoder | Hiera + Memory |
| Memory | Per-image | Streaming memory bank |
| Tracking | No | Yes, across frames |
| Models | ViT-B/L/H | Hiera-T/S/B+/L |
## Grounded SAM (Text-Prompted Segmentation)
### Setup
```bash
pip install groundingdino-py
pip install git+https://github.com/facebookresearch/segment-anything.git
```
### Text-to-mask pipeline
```python
from groundingdino.util.inference import load_model, predict
from segment_anything import sam_model_registry, SamPredictor
import cv2
import numpy as np
# Load Grounding DINO
grounding_model = load_model("groundingdino_swint_ogc.pth", "GroundingDINO_SwinT_OGC.py")
# Load SAM
sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
predictor = SamPredictor(sam)
def text_to_mask(image, text_prompt, box_threshold=0.3, text_threshold=0.25):
"""Generate masks from text description."""
# Get bounding boxes from text
boxes, logits, phrases = predict(
model=grounding_model,
image=image,
caption=text_prompt,
box_threshold=box_threshold,
text_threshold=text_threshold
)
# Generate masks with SAM
predictor.set_image(image)
masks = []
for box in boxes:
# Convert normalized box to pixel coordinates
h, w = image.shape[:2]
box_pixels = box * np.array([w, h, w, h])
mask, score, _ = predictor.predict(
box=box_pixels,
multimask_output=False
)
masks.append(mask[0])
return masks, boxes, phrases
# Usage
image = cv2.imread("image.jpg")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
masks, boxes, phrases = text_to_mask(image, "person . dog . car")
```
## Batched Processing
### Efficient multi-image processing
```python
import torch
from segment_anything import SamPredictor, sam_model_registry
class BatchedSAM:
def __init__(self, checkpoint, model_type="vit_h", device="cuda"):
self.sam = sam_model_registry[model_type](checkpoint=checkpoint)
self.sam.to(device)
self.predictor = SamPredictor(self.sam)
self.device = device
def process_batch(self, images, prompts):
"""Process multiple images with corresponding prompts."""
results = []
for image, prompt in zip(images, prompts):
self.predictor.set_image(image)
if "point" in prompt:
masks, scores, _ = self.predictor.predict(
point_coords=prompt["point"],
point_labels=prompt["label"],
multimask_output=True
)
elif "box" in prompt:
masks, scores, _ = self.predictor.predict(
box=prompt["box"],
multimask_output=False
)
results.append({
"masks": masks,
"scores": scores,
"best_mask": masks[np.argmax(scores)]
})
return results
# Usage
batch_sam = BatchedSAM("sam_vit_h_4b8939.pth")
images = [cv2.imread(f"image_{i}.jpg") for i in range(10)]
prompts = [{"point": np.array([[100, 100]]), "label": np.array([1])} for _ in range(10)]
results = batch_sam.process_batch(images, prompts)
```
### Parallel automatic mask generation
```python
from concurrent.futures import ThreadPoolExecutor
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
def generate_masks_parallel(images, num_workers=4):
"""Generate masks for multiple images in parallel."""
# Note: Each worker needs its own model instance
def worker_init():
sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")
return SamAutomaticMaskGenerator(sam)
generators = [worker_init() for _ in range(num_workers)]
def process_image(args):
idx, image = args
generator = generators[idx % num_workers]
return generator.generate(image)
with ThreadPoolExecutor(max_workers=num_workers) as executor:
results = list(executor.map(process_image, enumerate(images)))
return results
```
## Custom Integration
### FastAPI service
```python
from fastapi import FastAPI, File, UploadFile
from pydantic import BaseModel
import numpy as np
import cv2

from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator
app = FastAPI()
# Load model once
sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
sam.to("cuda")
predictor = SamPredictor(sam)
class PointPrompt(BaseModel):
x: int
y: int
label: int = 1
@app.post("/segment/point")
async def segment_with_point(
file: UploadFile = File(...),
points: list[PointPrompt] = []
):
# Read image
contents = await file.read()
nparr = np.frombuffer(contents, np.uint8)
image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# Set image
predictor.set_image(image)
# Prepare prompts
point_coords = np.array([[p.x, p.y] for p in points])
point_labels = np.array([p.label for p in points])
# Generate masks
masks, scores, _ = predictor.predict(
point_coords=point_coords,
point_labels=point_labels,
multimask_output=True
)
best_idx = np.argmax(scores)
return {
"mask": masks[best_idx].tolist(),
"score": float(scores[best_idx]),
"all_scores": scores.tolist()
}
@app.post("/segment/auto")
async def segment_automatic(file: UploadFile = File(...)):
contents = await file.read()
nparr = np.frombuffer(contents, np.uint8)
image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
mask_generator = SamAutomaticMaskGenerator(sam)
masks = mask_generator.generate(image)
return {
"num_masks": len(masks),
"masks": [
{
"bbox": m["bbox"],
"area": m["area"],
"predicted_iou": m["predicted_iou"],
"stability_score": m["stability_score"]
}
for m in masks
]
}
```
### Gradio interface
```python
import gradio as gr
import numpy as np
from segment_anything import sam_model_registry, SamPredictor
# Load model
sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
predictor = SamPredictor(sam)
def segment_image(image, evt: gr.SelectData):
"""Segment object at clicked point."""
predictor.set_image(image)
point = np.array([[evt.index[0], evt.index[1]]])
label = np.array([1])
masks, scores, _ = predictor.predict(
point_coords=point,
point_labels=label,
multimask_output=True
)
best_mask = masks[np.argmax(scores)]
# Overlay mask on image
overlay = image.copy()
overlay[best_mask] = overlay[best_mask] * 0.5 + np.array([255, 0, 0]) * 0.5
return overlay
with gr.Blocks() as demo:
gr.Markdown("# SAM Interactive Segmentation")
gr.Markdown("Click on an object to segment it")
with gr.Row():
input_image = gr.Image(label="Input Image", interactive=True)
output_image = gr.Image(label="Segmented Image")
input_image.select(segment_image, inputs=[input_image], outputs=[output_image])
demo.launch()
```
## Fine-Tuning SAM
### LoRA fine-tuning (experimental)
```python
from peft import LoraConfig, get_peft_model
from transformers import SamModel
# Load model
model = SamModel.from_pretrained("facebook/sam-vit-base")
# Configure LoRA
lora_config = LoraConfig(
r=16,
lora_alpha=32,
target_modules=["qkv"], # Attention layers
lora_dropout=0.1,
bias="none",
)
# Apply LoRA
model = get_peft_model(model, lora_config)
# Training loop (simplified)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
for batch in dataloader:
outputs = model(
pixel_values=batch["pixel_values"],
input_points=batch["input_points"],
input_labels=batch["input_labels"]
)
# Custom loss (e.g., IoU loss with ground truth)
loss = compute_loss(outputs.pred_masks, batch["gt_masks"])
loss.backward()
optimizer.step()
optimizer.zero_grad()
```
### MedSAM (Medical imaging)
```python
# MedSAM is a fine-tuned SAM for medical images
# https://github.com/bowang-lab/MedSAM
from segment_anything import sam_model_registry, SamPredictor
import torch
# Load MedSAM checkpoint
medsam = sam_model_registry["vit_b"](checkpoint="medsam_vit_b.pth")
medsam.to("cuda")
predictor = SamPredictor(medsam)
# Process medical image
# Convert grayscale to RGB if needed
medical_image = cv2.imread("ct_scan.png", cv2.IMREAD_GRAYSCALE)
rgb_image = np.stack([medical_image] * 3, axis=-1)
predictor.set_image(rgb_image)
# Segment with box prompt (common for medical imaging)
masks, scores, _ = predictor.predict(
box=np.array([x1, y1, x2, y2]),
multimask_output=False
)
```
## Advanced Mask Processing
### Mask refinement
```python
import cv2
from scipy import ndimage
def refine_mask(mask, kernel_size=5, iterations=2):
"""Refine mask with morphological operations."""
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))
# Close small holes
closed = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, kernel, iterations=iterations)
# Remove small noise
opened = cv2.morphologyEx(closed, cv2.MORPH_OPEN, kernel, iterations=iterations)
return opened.astype(bool)
def fill_holes(mask):
"""Fill holes in mask."""
filled = ndimage.binary_fill_holes(mask)
return filled
def remove_small_regions(mask, min_area=100):
"""Remove small disconnected regions."""
labeled, num_features = ndimage.label(mask)
sizes = ndimage.sum(mask, labeled, range(1, num_features + 1))
# Keep only regions larger than min_area
mask_clean = np.zeros_like(mask)
for i, size in enumerate(sizes, 1):
if size >= min_area:
mask_clean[labeled == i] = True
return mask_clean
```
### Mask to polygon conversion
```python
import cv2
def mask_to_polygons(mask, epsilon_factor=0.01):
"""Convert binary mask to polygon coordinates."""
contours, _ = cv2.findContours(
mask.astype(np.uint8),
cv2.RETR_EXTERNAL,
cv2.CHAIN_APPROX_SIMPLE
)
polygons = []
for contour in contours:
epsilon = epsilon_factor * cv2.arcLength(contour, True)
approx = cv2.approxPolyDP(contour, epsilon, True)
polygon = approx.squeeze().tolist()
if len(polygon) >= 3: # Valid polygon
polygons.append(polygon)
return polygons
def polygons_to_mask(polygons, height, width):
"""Convert polygons back to binary mask."""
mask = np.zeros((height, width), dtype=np.uint8)
for polygon in polygons:
pts = np.array(polygon, dtype=np.int32)
cv2.fillPoly(mask, [pts], 1)
return mask.astype(bool)
```
### Multi-scale segmentation
```python
def multiscale_segment(image, predictor, point, scales=[0.5, 1.0, 2.0]):
"""Generate masks at multiple scales and combine."""
h, w = image.shape[:2]
masks_all = []
for scale in scales:
# Resize image
new_h, new_w = int(h * scale), int(w * scale)
scaled_image = cv2.resize(image, (new_w, new_h))
scaled_point = (point * scale).astype(int)
# Segment
predictor.set_image(scaled_image)
masks, scores, _ = predictor.predict(
point_coords=scaled_point.reshape(1, 2),
point_labels=np.array([1]),
multimask_output=True
)
# Resize mask back
best_mask = masks[np.argmax(scores)]
original_mask = cv2.resize(best_mask.astype(np.uint8), (w, h)) > 0.5
masks_all.append(original_mask)
# Combine masks (majority voting)
combined = np.stack(masks_all, axis=0)
final_mask = np.sum(combined, axis=0) >= len(scales) // 2 + 1
return final_mask
```
## Performance Optimization
### TensorRT acceleration
```python
import tensorrt as trt
import pycuda.driver as cuda
import pycuda.autoinit
def export_to_tensorrt(onnx_path, engine_path, fp16=True):
"""Convert ONNX model to TensorRT engine."""
logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
parser = trt.OnnxParser(network, logger)
with open(onnx_path, 'rb') as f:
if not parser.parse(f.read()):
for error in range(parser.num_errors):
print(parser.get_error(error))
return None
config = builder.create_builder_config()
config.max_workspace_size = 1 << 30 # 1GB
if fp16:
config.set_flag(trt.BuilderFlag.FP16)
engine = builder.build_engine(network, config)
with open(engine_path, 'wb') as f:
f.write(engine.serialize())
return engine
```
### Memory-efficient inference
```python
class MemoryEfficientSAM:
def __init__(self, checkpoint, model_type="vit_b"):
self.sam = sam_model_registry[model_type](checkpoint=checkpoint)
self.sam.eval()
self.predictor = None
def __enter__(self):
self.sam.to("cuda")
self.predictor = SamPredictor(self.sam)
return self
def __exit__(self, *args):
self.sam.to("cpu")
torch.cuda.empty_cache()
def segment(self, image, points, labels):
self.predictor.set_image(image)
masks, scores, _ = self.predictor.predict(
point_coords=points,
point_labels=labels,
multimask_output=True
)
return masks, scores
# Usage with context manager (auto-cleanup)
with MemoryEfficientSAM("sam_vit_b_01ec64.pth") as sam:
masks, scores = sam.segment(image, points, labels)
# CUDA memory freed automatically
```
## Dataset Generation
### Create segmentation dataset
```python
import json
from pathlib import Path

import cv2
import numpy as np
from pycocotools import mask as mask_utils

def mask_to_rle(mask):
    """Encode a binary mask as COCO RLE (JSON-serializable), as in the COCO RLE section above."""
    rle = mask_utils.encode(np.asfortranarray(mask.astype(np.uint8)))
    rle["counts"] = rle["counts"].decode("utf-8")
    return rle

def generate_dataset(images_dir, output_dir, mask_generator):
"""Generate segmentation dataset from images."""
annotations = []
for img_path in Path(images_dir).glob("*.jpg"):
image = cv2.imread(str(img_path))
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# Generate masks
masks = mask_generator.generate(image)
# Filter high-quality masks
good_masks = [m for m in masks if m["predicted_iou"] > 0.9]
# Save annotations
for i, mask_data in enumerate(good_masks):
annotation = {
"image_id": img_path.stem,
"mask_id": i,
"bbox": mask_data["bbox"],
"area": mask_data["area"],
"segmentation": mask_to_rle(mask_data["segmentation"]),
"predicted_iou": mask_data["predicted_iou"],
"stability_score": mask_data["stability_score"]
}
annotations.append(annotation)
# Save dataset
with open(output_dir / "annotations.json", "w") as f:
json.dump(annotations, f)
return annotations
```

View file

@ -0,0 +1,484 @@
# Segment Anything Troubleshooting Guide
## Installation Issues
### CUDA not available
**Error**: `RuntimeError: CUDA not available`
**Solutions**:
```python
# Check CUDA availability
import torch
print(torch.cuda.is_available())
print(torch.version.cuda)
# Install PyTorch with CUDA (shell command, not Python):
#   pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121
# If CUDA works but SAM doesn't use it
sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
sam.to("cuda") # Explicitly move to GPU
```
### Import errors
**Error**: `ModuleNotFoundError: No module named 'segment_anything'`
**Solutions**:
```bash
# Install from GitHub
pip install git+https://github.com/facebookresearch/segment-anything.git
# Or clone and install
git clone https://github.com/facebookresearch/segment-anything.git
cd segment-anything
pip install -e .
# Verify installation
python -c "from segment_anything import sam_model_registry; print('OK')"
```
### Missing dependencies
**Error**: `ModuleNotFoundError: No module named 'cv2'` or similar
**Solutions**:
```bash
# Install all optional dependencies
pip install opencv-python pycocotools matplotlib onnxruntime onnx
# For pycocotools on Windows
pip install pycocotools-windows
```
## Model Loading Issues
### Checkpoint not found
**Error**: `FileNotFoundError: checkpoint file not found`
**Solutions**:
```bash
# Download correct checkpoint
wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
# Verify file integrity
md5sum sam_vit_h_4b8939.pth
# Expected: a7bf3b02f3ebf1267aba913ff637d9a2
# Use absolute path
sam = sam_model_registry["vit_h"](checkpoint="/full/path/to/sam_vit_h_4b8939.pth")
```
### Model type mismatch
**Error**: `KeyError: 'unexpected key in state_dict'`
**Solutions**:
```python
# Ensure model type matches checkpoint
# vit_h checkpoint → vit_h model
sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
# vit_l checkpoint → vit_l model
sam = sam_model_registry["vit_l"](checkpoint="sam_vit_l_0b3195.pth")
# vit_b checkpoint → vit_b model
sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")
```
### Out of memory during load
**Error**: `CUDA out of memory` during model loading
**Solutions**:
```python
# Use smaller model
sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")
# Load to CPU first, then move
sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
sam.to("cpu")
torch.cuda.empty_cache()
sam.to("cuda")
# Use half precision
sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
sam = sam.half()
sam.to("cuda")
```
## Inference Issues
### Image format errors
**Error**: `ValueError: expected input to have 3 channels`
**Solutions**:
```python
import cv2
# Ensure RGB format
image = cv2.imread("image.jpg")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # BGR to RGB
# Convert grayscale to RGB
if len(image.shape) == 2:
image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
# Handle RGBA
if image.shape[2] == 4:
image = image[:, :, :3] # Drop alpha channel
```
### Coordinate errors
**Error**: `IndexError: index out of bounds` or incorrect mask location
**Solutions**:
```python
# Ensure points are (x, y) not (row, col)
# x = column index, y = row index
point = np.array([[x, y]]) # Correct
# Verify coordinates are within image bounds
h, w = image.shape[:2]
assert 0 <= x < w and 0 <= y < h, "Point outside image"
# For bounding boxes: [x1, y1, x2, y2]
box = np.array([x1, y1, x2, y2])
assert x1 < x2 and y1 < y2, "Invalid box coordinates"
```
### Empty or incorrect masks
**Problem**: Masks don't match expected object
**Solutions**:
```python
# Try multiple prompts
input_points = np.array([[x1, y1], [x2, y2]])
input_labels = np.array([1, 1]) # Multiple foreground points
# Add background points
input_points = np.array([[obj_x, obj_y], [bg_x, bg_y]])
input_labels = np.array([1, 0]) # 1=foreground, 0=background
# Use box prompt for large objects
box = np.array([x1, y1, x2, y2])
masks, scores, _ = predictor.predict(box=box, multimask_output=False)
# Combine box and point
masks, scores, _ = predictor.predict(
point_coords=np.array([[center_x, center_y]]),
point_labels=np.array([1]),
box=np.array([x1, y1, x2, y2]),
multimask_output=True
)
# Check scores and select best
print(f"Scores: {scores}")
best_mask = masks[np.argmax(scores)]
```
### Slow inference
**Problem**: Prediction takes too long
**Solutions**:
```python
# Use smaller model
sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")
# Reuse image embeddings
predictor.set_image(image) # Compute once
for point in points:
masks, _, _ = predictor.predict(...) # Fast, reuses embeddings
# Reduce automatic generation points
mask_generator = SamAutomaticMaskGenerator(
model=sam,
points_per_side=16, # Default is 32
)
# Use ONNX for deployment
# Export: python scripts/export_onnx_model.py --return-single-mask
```
## Automatic Mask Generation Issues
### Too many masks
**Problem**: Generating thousands of overlapping masks
**Solutions**:
```python
mask_generator = SamAutomaticMaskGenerator(
model=sam,
points_per_side=16, # Reduce from 32
pred_iou_thresh=0.92, # Increase from 0.88
stability_score_thresh=0.98, # Increase from 0.95
box_nms_thresh=0.5, # More aggressive NMS
min_mask_region_area=500, # Remove small masks
)
```
### Too few masks
**Problem**: Missing objects in automatic generation
**Solutions**:
```python
mask_generator = SamAutomaticMaskGenerator(
model=sam,
points_per_side=64, # Increase density
pred_iou_thresh=0.80, # Lower threshold
stability_score_thresh=0.85, # Lower threshold
crop_n_layers=2, # Add multi-scale
min_mask_region_area=0, # Keep all masks
)
```
### Small objects missed
**Problem**: Automatic generation misses small objects
**Solutions**:
```python
# Use crop layers for multi-scale detection
mask_generator = SamAutomaticMaskGenerator(
model=sam,
crop_n_layers=2,
crop_n_points_downscale_factor=1, # Don't reduce points in crops
min_mask_region_area=10, # Very small minimum
)
# Or process image patches
def segment_with_patches(image, patch_size=512, overlap=64):
h, w = image.shape[:2]
all_masks = []
for y in range(0, h, patch_size - overlap):
for x in range(0, w, patch_size - overlap):
patch = image[y:y+patch_size, x:x+patch_size]
masks = mask_generator.generate(patch)
# Offset masks to original coordinates
for m in masks:
m['bbox'][0] += x
m['bbox'][1] += y
# Offset segmentation mask too
all_masks.extend(masks)
return all_masks
```
## Memory Issues
### CUDA out of memory
**Error**: `torch.cuda.OutOfMemoryError: CUDA out of memory`
**Solutions**:
```python
# Use smaller model
sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")
# Clear cache between images
torch.cuda.empty_cache()
# Process images sequentially, not batched
for image in images:
predictor.set_image(image)
masks, _, _ = predictor.predict(...)
torch.cuda.empty_cache()
# Reduce image size
max_size = 1024
h, w = image.shape[:2]
if max(h, w) > max_size:
scale = max_size / max(h, w)
image = cv2.resize(image, (int(w*scale), int(h*scale)))
# Use CPU for large batch processing
sam.to("cpu")
```
### RAM out of memory
**Problem**: System runs out of RAM
**Solutions**:
```python
import gc

# Process images one at a time
for img_path in image_paths:
image = cv2.imread(img_path)
masks = process_image(image)
save_results(masks)
del image, masks
gc.collect()
# Use generators instead of lists
def generate_masks_lazy(image_paths):
for path in image_paths:
image = cv2.imread(path)
masks = mask_generator.generate(image)
yield path, masks
```
## ONNX Export Issues
### Export fails
**Error**: Various export errors
**Solutions**:
```bash
# Install correct ONNX version
pip install onnx==1.14.0 onnxruntime==1.15.0
# Use correct opset version
python scripts/export_onnx_model.py \
--checkpoint sam_vit_h_4b8939.pth \
--model-type vit_h \
--output sam.onnx \
--opset 17
```
### ONNX runtime errors
**Error**: `ONNXRuntimeError` during inference
**Solutions**:
```python
import onnxruntime
# Check available providers
print(onnxruntime.get_available_providers())
# Use CPU provider if GPU fails
session = onnxruntime.InferenceSession(
"sam.onnx",
providers=['CPUExecutionProvider']
)
# Verify input shapes
for input in session.get_inputs():
print(f"{input.name}: {input.shape}")
```
## HuggingFace Integration Issues
### Processor errors
**Error**: Issues with SamProcessor
**Solutions**:
```python
from transformers import SamModel, SamProcessor
# Use matching processor and model
model = SamModel.from_pretrained("facebook/sam-vit-huge")
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
# Ensure input format
input_points = [[[x, y]]] # Nested list for batch dimension
inputs = processor(image, input_points=input_points, return_tensors="pt")
# Post-process correctly
masks = processor.image_processor.post_process_masks(
outputs.pred_masks.cpu(),
inputs["original_sizes"].cpu(),
inputs["reshaped_input_sizes"].cpu()
)
```
## Quality Issues
### Jagged mask edges
**Problem**: Masks have rough, pixelated edges
**Solutions**:
```python
import cv2
from scipy import ndimage
def smooth_mask(mask, sigma=2):
"""Smooth mask edges."""
# Gaussian blur
smooth = ndimage.gaussian_filter(mask.astype(float), sigma=sigma)
return smooth > 0.5
def refine_edges(mask, kernel_size=5):
"""Refine mask edges with morphological operations."""
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))
# Close small gaps
closed = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, kernel)
# Open to remove noise
opened = cv2.morphologyEx(closed, cv2.MORPH_OPEN, kernel)
return opened.astype(bool)
```
### Incomplete segmentation
**Problem**: Mask doesn't cover entire object
**Solutions**:
```python
# Add multiple points
input_points = np.array([
[obj_center_x, obj_center_y],
[obj_left_x, obj_center_y],
[obj_right_x, obj_center_y],
[obj_center_x, obj_top_y],
[obj_center_x, obj_bottom_y]
])
input_labels = np.array([1, 1, 1, 1, 1])
# Use bounding box
masks, _, _ = predictor.predict(
box=np.array([x1, y1, x2, y2]),
multimask_output=False
)
# Iterative refinement
mask_input = None
for point in points:
masks, scores, logits = predictor.predict(
point_coords=point.reshape(1, 2),
point_labels=np.array([1]),
mask_input=mask_input,
multimask_output=False
)
mask_input = logits
```
## Common Error Messages
| Error | Cause | Solution |
|-------|-------|----------|
| `CUDA out of memory` | GPU memory full | Use smaller model, clear cache |
| `expected 3 channels` | Wrong image format | Convert to RGB |
| `index out of bounds` | Invalid coordinates | Check point/box bounds |
| `checkpoint not found` | Wrong path | Use absolute path |
| `unexpected key` | Model/checkpoint mismatch | Match model type |
| `invalid box coordinates` | x1 > x2 or y1 > y2 | Fix box format |
## Getting Help
1. **GitHub Issues**: https://github.com/facebookresearch/segment-anything/issues
2. **HuggingFace Forums**: https://discuss.huggingface.co
3. **Paper**: https://arxiv.org/abs/2304.02643
### Reporting Issues
Include:
- Python version
- PyTorch version: `python -c "import torch; print(torch.__version__)"`
- CUDA version: `python -c "import torch; print(torch.version.cuda)"`
- SAM model type (vit_b/l/h)
- Full error traceback
- Minimal reproducible code

View file

@ -0,0 +1,522 @@
---
name: stable-diffusion-image-generation
description: State-of-the-art text-to-image generation with Stable Diffusion models via HuggingFace Diffusers. Use when generating images from text prompts, performing image-to-image translation, inpainting, or building custom diffusion pipelines.
version: 1.0.0
author: Orchestra Research
license: MIT
dependencies: [diffusers>=0.30.0, transformers>=4.41.0, accelerate>=0.31.0, torch>=2.0.0]
metadata:
hermes:
tags: [Image Generation, Stable Diffusion, Diffusers, Text-to-Image, Multimodal, Computer Vision]
---
# Stable Diffusion Image Generation
Comprehensive guide to generating images with Stable Diffusion using the HuggingFace Diffusers library.
## When to use Stable Diffusion
**Use Stable Diffusion when:**
- Generating images from text descriptions
- Performing image-to-image translation (style transfer, enhancement)
- Inpainting (filling in masked regions)
- Outpainting (extending images beyond boundaries)
- Creating variations of existing images
- Building custom image generation workflows
**Key features:**
- **Text-to-Image**: Generate images from natural language prompts
- **Image-to-Image**: Transform existing images with text guidance
- **Inpainting**: Fill masked regions with context-aware content
- **ControlNet**: Add spatial conditioning (edges, poses, depth)
- **LoRA Support**: Efficient fine-tuning and style adaptation
- **Multiple Models**: SD 1.5, SDXL, SD 3.0, Flux support
**Use alternatives instead:**
- **DALL-E 3**: For API-based generation without GPU
- **Midjourney**: For artistic, stylized outputs
- **Imagen**: For Google Cloud integration
- **Leonardo.ai**: For web-based creative workflows
## Quick start
### Installation
```bash
pip install diffusers transformers accelerate torch
pip install xformers # Optional: memory-efficient attention
```
### Basic text-to-image
```python
from diffusers import DiffusionPipeline
import torch
# Load pipeline (auto-detects model type)
pipe = DiffusionPipeline.from_pretrained(
"stable-diffusion-v1-5/stable-diffusion-v1-5",
torch_dtype=torch.float16
)
pipe.to("cuda")
# Generate image
image = pipe(
"A serene mountain landscape at sunset, highly detailed",
num_inference_steps=50,
guidance_scale=7.5
).images[0]
image.save("output.png")
```
### Using SDXL (higher quality)
```python
from diffusers import AutoPipelineForText2Image
import torch
pipe = AutoPipelineForText2Image.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
torch_dtype=torch.float16,
variant="fp16"
)
# Either move everything to the GPU...
pipe.to("cuda")
# ...or, on memory-constrained GPUs, enable CPU offload instead of .to("cuda"):
# pipe.enable_model_cpu_offload()
image = pipe(
prompt="A futuristic city with flying cars, cinematic lighting",
height=1024,
width=1024,
num_inference_steps=30
).images[0]
```
## Architecture overview
### Three-pillar design
Diffusers is built around three core components:
```
Pipeline (orchestration)
├── Model (neural networks)
│ ├── UNet / Transformer (noise prediction)
│ ├── VAE (latent encoding/decoding)
│ └── Text Encoder (CLIP/T5)
└── Scheduler (denoising algorithm)
```
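These components are exposed as attributes on a loaded pipeline, which is what makes swapping parts (for example the scheduler, shown later) possible. A quick sketch, assuming the SD 1.5 pipeline from the quick start:
```python
# Inspect the components behind a loaded pipeline
from diffusers import DiffusionPipeline
import torch

pipe = DiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
)
print(type(pipe.unet).__name__)          # noise-prediction UNet
print(type(pipe.vae).__name__)           # latent encoder/decoder
print(type(pipe.text_encoder).__name__)  # CLIP text encoder
print(type(pipe.scheduler).__name__)     # denoising scheduler
```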
### Pipeline inference flow
```
Text Prompt → Text Encoder → Text Embeddings
                                  ↓
Random Noise → [Denoising Loop] ← Scheduler
                     ↓  (predicted noise removed each step)
              VAE Decoder → Final Image
```
## Core concepts
### Pipelines
Pipelines orchestrate complete workflows:
| Pipeline | Purpose |
|----------|---------|
| `StableDiffusionPipeline` | Text-to-image (SD 1.x/2.x) |
| `StableDiffusionXLPipeline` | Text-to-image (SDXL) |
| `StableDiffusion3Pipeline` | Text-to-image (SD 3.0) |
| `FluxPipeline` | Text-to-image (Flux models) |
| `StableDiffusionImg2ImgPipeline` | Image-to-image |
| `StableDiffusionInpaintPipeline` | Inpainting |
### Schedulers
Schedulers control the denoising process:
| Scheduler | Steps | Quality | Use Case |
|-----------|-------|---------|----------|
| `EulerDiscreteScheduler` | 20-50 | Good | Default choice |
| `EulerAncestralDiscreteScheduler` | 20-50 | Good | More variation |
| `DPMSolverMultistepScheduler` | 15-25 | Excellent | Fast, high quality |
| `DDIMScheduler` | 50-100 | Good | Deterministic |
| `LCMScheduler` | 4-8 | Good | Very fast |
| `UniPCMultistepScheduler` | 15-25 | Excellent | Fast convergence |
### Swapping schedulers
```python
from diffusers import DPMSolverMultistepScheduler
# Swap for faster generation
pipe.scheduler = DPMSolverMultistepScheduler.from_config(
pipe.scheduler.config
)
# Now generate with fewer steps
image = pipe(prompt, num_inference_steps=20).images[0]
```
## Generation parameters
### Key parameters
| Parameter | Default | Description |
|-----------|---------|-------------|
| `prompt` | Required | Text description of desired image |
| `negative_prompt` | None | What to avoid in the image |
| `num_inference_steps` | 50 | Denoising steps (more = better quality) |
| `guidance_scale` | 7.5 | Prompt adherence (7-12 typical) |
| `height`, `width` | 512/1024 | Output dimensions (multiples of 8) |
| `generator` | None | Torch generator for reproducibility |
| `num_images_per_prompt` | 1 | Batch size |
### Reproducible generation
```python
import torch
generator = torch.Generator(device="cuda").manual_seed(42)
image = pipe(
prompt="A cat wearing a top hat",
generator=generator,
num_inference_steps=50
).images[0]
```
### Negative prompts
```python
image = pipe(
prompt="Professional photo of a dog in a garden",
negative_prompt="blurry, low quality, distorted, ugly, bad anatomy",
guidance_scale=7.5
).images[0]
```
## Image-to-image
Transform existing images with text guidance:
```python
from diffusers import AutoPipelineForImage2Image
from PIL import Image
pipe = AutoPipelineForImage2Image.from_pretrained(
"stable-diffusion-v1-5/stable-diffusion-v1-5",
torch_dtype=torch.float16
).to("cuda")
init_image = Image.open("input.jpg").resize((512, 512))
image = pipe(
prompt="A watercolor painting of the scene",
image=init_image,
strength=0.75, # How much to transform (0-1)
num_inference_steps=50
).images[0]
```
## Inpainting
Fill masked regions:
```python
from diffusers import AutoPipelineForInpainting
from PIL import Image
pipe = AutoPipelineForInpainting.from_pretrained(
"runwayml/stable-diffusion-inpainting",
torch_dtype=torch.float16
).to("cuda")
image = Image.open("photo.jpg")
mask = Image.open("mask.png") # White = inpaint region
result = pipe(
prompt="A red car parked on the street",
image=image,
mask_image=mask,
num_inference_steps=50
).images[0]
```
## ControlNet
Add spatial conditioning for precise control:
```python
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
import torch
# Load ControlNet for edge conditioning
controlnet = ControlNetModel.from_pretrained(
"lllyasviel/control_v11p_sd15_canny",
torch_dtype=torch.float16
)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
"stable-diffusion-v1-5/stable-diffusion-v1-5",
controlnet=controlnet,
torch_dtype=torch.float16
).to("cuda")
# Use a Canny edge image as the control (controlnet_aux provides the detector)
from controlnet_aux import CannyDetector
from diffusers.utils import load_image
input_image = load_image("input.jpg")
control_image = CannyDetector()(input_image)
image = pipe(
prompt="A beautiful house in the style of Van Gogh",
image=control_image,
num_inference_steps=30
).images[0]
```
### Available ControlNets
| ControlNet | Input Type | Use Case |
|------------|------------|----------|
| `canny` | Edge maps | Preserve structure |
| `openpose` | Pose skeletons | Human poses |
| `depth` | Depth maps | 3D-aware generation |
| `normal` | Normal maps | Surface details |
| `mlsd` | Line segments | Architectural lines |
| `scribble` | Rough sketches | Sketch-to-image |
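Other conditioning types follow the same pattern as the Canny example above. A sketch of pose conditioning, assuming controlnet_aux's `OpenposeDetector` and the `lllyasviel/control_v11p_sd15_openpose` checkpoint:
```python
from controlnet_aux import OpenposeDetector
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
from diffusers.utils import load_image
import torch
# Extract a pose skeleton from a reference photo
openpose = OpenposeDetector.from_pretrained("lllyasviel/Annotators")
pose_image = openpose(load_image("person.jpg"))
controlnet = ControlNetModel.from_pretrained(
    "lllyasviel/control_v11p_sd15_openpose",
    torch_dtype=torch.float16
)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5",
    controlnet=controlnet,
    torch_dtype=torch.float16
).to("cuda")
image = pipe(
    "A dancer on a stage, dramatic lighting",
    image=pose_image,
    num_inference_steps=30
).images[0]
```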
## LoRA adapters
Load fine-tuned style adapters:
```python
from diffusers import DiffusionPipeline
pipe = DiffusionPipeline.from_pretrained(
"stable-diffusion-v1-5/stable-diffusion-v1-5",
torch_dtype=torch.float16
).to("cuda")
# Load LoRA weights
pipe.load_lora_weights("path/to/lora", weight_name="style.safetensors")
# Generate with LoRA style
image = pipe("A portrait in the trained style").images[0]
# Adjust LoRA strength
pipe.fuse_lora(lora_scale=0.8)
# Unload LoRA
pipe.unload_lora_weights()
```
### Multiple LoRAs
```python
# Load multiple LoRAs
pipe.load_lora_weights("lora1", adapter_name="style")
pipe.load_lora_weights("lora2", adapter_name="character")
# Set weights for each
pipe.set_adapters(["style", "character"], adapter_weights=[0.7, 0.5])
image = pipe("A portrait").images[0]
```
## Memory optimization
### Enable CPU offloading
```python
# Model CPU offload - moves models to CPU when not in use
pipe.enable_model_cpu_offload()
# Sequential CPU offload - more aggressive, slower
pipe.enable_sequential_cpu_offload()
```
### Attention slicing
```python
# Reduce memory by computing attention in chunks
pipe.enable_attention_slicing()
# Or specific chunk size
pipe.enable_attention_slicing("max")
```
### xFormers memory-efficient attention
```python
# Requires xformers package
pipe.enable_xformers_memory_efficient_attention()
```
### VAE slicing for large images
```python
# Decode latents in tiles for large images
pipe.enable_vae_slicing()
pipe.enable_vae_tiling()
```
## Model variants
### Loading different precisions
```python
# FP16 (recommended for GPU)
pipe = DiffusionPipeline.from_pretrained(
"model-id",
torch_dtype=torch.float16,
variant="fp16"
)
# BF16 (wider dynamic range, requires Ampere or newer GPU)
pipe = DiffusionPipeline.from_pretrained(
"model-id",
torch_dtype=torch.bfloat16
)
```
### Loading specific components
```python
from diffusers import UNet2DConditionModel, AutoencoderKL
# Load custom VAE
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
# Use with pipeline
pipe = DiffusionPipeline.from_pretrained(
"stable-diffusion-v1-5/stable-diffusion-v1-5",
vae=vae,
torch_dtype=torch.float16
)
```
## Batch generation
Generate multiple images efficiently:
```python
# Multiple prompts
prompts = [
"A cat playing piano",
"A dog reading a book",
"A bird painting a picture"
]
images = pipe(prompts, num_inference_steps=30).images
# Multiple images per prompt
images = pipe(
"A beautiful sunset",
num_images_per_prompt=4,
num_inference_steps=30
).images
```
## Common workflows
### Workflow 1: High-quality generation
```python
from diffusers import StableDiffusionXLPipeline, DPMSolverMultistepScheduler
import torch
# 1. Load SDXL with optimizations
pipe = StableDiffusionXLPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
torch_dtype=torch.float16,
variant="fp16"
)
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
pipe.enable_model_cpu_offload()  # manages device placement; no separate pipe.to("cuda") needed
# 2. Generate with quality settings
image = pipe(
prompt="A majestic lion in the savanna, golden hour lighting, 8k, detailed fur",
negative_prompt="blurry, low quality, cartoon, anime, sketch",
num_inference_steps=30,
guidance_scale=7.5,
height=1024,
width=1024
).images[0]
```
### Workflow 2: Fast prototyping
```python
from diffusers import AutoPipelineForText2Image, LCMScheduler
import torch
# Use LCM for 4-8 step generation
pipe = AutoPipelineForText2Image.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
torch_dtype=torch.float16
).to("cuda")
# Load LCM LoRA for fast generation
pipe.load_lora_weights("latent-consistency/lcm-lora-sdxl")
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
pipe.fuse_lora()
# Generate in ~1 second
image = pipe(
"A beautiful landscape",
num_inference_steps=4,
guidance_scale=1.0
).images[0]
```
## Common issues
**CUDA out of memory:**
```python
# Enable memory optimizations
pipe.enable_model_cpu_offload()
pipe.enable_attention_slicing()
pipe.enable_vae_slicing()
# Or use lower precision
pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
```
**Black/noise images:**
```python
# Check VAE configuration
# Use safety checker bypass if needed
pipe.safety_checker = None
# Ensure proper dtype consistency
pipe = pipe.to(dtype=torch.float16)
```
**Slow generation:**
```python
# Use faster scheduler
from diffusers import DPMSolverMultistepScheduler
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
# Reduce steps
image = pipe(prompt, num_inference_steps=20).images[0]
```
## References
- **[Advanced Usage](references/advanced-usage.md)** - Custom pipelines, fine-tuning, deployment
- **[Troubleshooting](references/troubleshooting.md)** - Common issues and solutions
## Resources
- **Documentation**: https://huggingface.co/docs/diffusers
- **Repository**: https://github.com/huggingface/diffusers
- **Model Hub**: https://huggingface.co/models?library=diffusers
- **Discord**: https://discord.gg/diffusers

View file

@@ -0,0 +1,716 @@
# Stable Diffusion Advanced Usage Guide
## Custom Pipelines
### Building from components
```python
from diffusers import (
UNet2DConditionModel,
AutoencoderKL,
DDPMScheduler,
StableDiffusionPipeline
)
from transformers import CLIPTextModel, CLIPTokenizer
import torch
# Load components individually
unet = UNet2DConditionModel.from_pretrained(
"stable-diffusion-v1-5/stable-diffusion-v1-5",
subfolder="unet"
)
vae = AutoencoderKL.from_pretrained(
"stable-diffusion-v1-5/stable-diffusion-v1-5",
subfolder="vae"
)
text_encoder = CLIPTextModel.from_pretrained(
"stable-diffusion-v1-5/stable-diffusion-v1-5",
subfolder="text_encoder"
)
tokenizer = CLIPTokenizer.from_pretrained(
"stable-diffusion-v1-5/stable-diffusion-v1-5",
subfolder="tokenizer"
)
scheduler = DDPMScheduler.from_pretrained(
"stable-diffusion-v1-5/stable-diffusion-v1-5",
subfolder="scheduler"
)
# Assemble pipeline
pipe = StableDiffusionPipeline(
unet=unet,
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
scheduler=scheduler,
safety_checker=None,
feature_extractor=None,
requires_safety_checker=False
)
```
### Custom denoising loop
```python
from diffusers import DDIMScheduler, AutoencoderKL, UNet2DConditionModel
from transformers import CLIPTextModel, CLIPTokenizer
from PIL import Image
import torch
def custom_generate(
prompt: str,
num_steps: int = 50,
guidance_scale: float = 7.5,
height: int = 512,
width: int = 512
):
# Load components
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
unet = UNet2DConditionModel.from_pretrained("sd-model", subfolder="unet")
vae = AutoencoderKL.from_pretrained("sd-model", subfolder="vae")
scheduler = DDIMScheduler.from_pretrained("sd-model", subfolder="scheduler")
device = "cuda"
text_encoder.to(device)
unet.to(device)
vae.to(device)
# Encode prompt
text_input = tokenizer(
prompt,
padding="max_length",
max_length=77,
truncation=True,
return_tensors="pt"
)
text_embeddings = text_encoder(text_input.input_ids.to(device))[0]
# Unconditional embeddings for classifier-free guidance
uncond_input = tokenizer(
"",
padding="max_length",
max_length=77,
return_tensors="pt"
)
uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]
# Concatenate for batch processing
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
# Initialize latents
latents = torch.randn(
(1, 4, height // 8, width // 8),
device=device
)
latents = latents * scheduler.init_noise_sigma
# Denoising loop
scheduler.set_timesteps(num_steps)
for t in scheduler.timesteps:
latent_model_input = torch.cat([latents] * 2)
latent_model_input = scheduler.scale_model_input(latent_model_input, t)
# Predict noise
with torch.no_grad():
noise_pred = unet(
latent_model_input,
t,
encoder_hidden_states=text_embeddings
).sample
# Classifier-free guidance
noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (
noise_pred_cond - noise_pred_uncond
)
# Update latents
latents = scheduler.step(noise_pred, t, latents).prev_sample
# Decode latents
latents = latents / vae.config.scaling_factor
with torch.no_grad():
image = vae.decode(latents).sample
# Convert to PIL
image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy()
image = (image * 255).round().astype("uint8")[0]
return Image.fromarray(image)
```
## IP-Adapter
Use image prompts alongside text:
```python
from diffusers import StableDiffusionPipeline
from diffusers.utils import load_image
import torch
pipe = StableDiffusionPipeline.from_pretrained(
"stable-diffusion-v1-5/stable-diffusion-v1-5",
torch_dtype=torch.float16
).to("cuda")
# Load IP-Adapter
pipe.load_ip_adapter(
"h94/IP-Adapter",
subfolder="models",
weight_name="ip-adapter_sd15.bin"
)
# Set IP-Adapter scale
pipe.set_ip_adapter_scale(0.6)
# Load reference image
ip_image = load_image("reference_style.jpg")
# Generate with image + text prompt
image = pipe(
prompt="A portrait in a garden",
ip_adapter_image=ip_image,
num_inference_steps=50
).images[0]
```
### Multiple IP-Adapter images
```python
# Use multiple reference images
pipe.set_ip_adapter_scale([0.5, 0.7])
images = [
load_image("style_reference.jpg"),
load_image("composition_reference.jpg")
]
result = pipe(
prompt="A landscape painting",
ip_adapter_image=images,
num_inference_steps=50
).images[0]
```
## SDXL Refiner
Two-stage generation for higher quality:
```python
from diffusers import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline
import torch
# Load base model
base = StableDiffusionXLPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
torch_dtype=torch.float16,
variant="fp16"
).to("cuda")
# Load refiner
refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-refiner-1.0",
torch_dtype=torch.float16,
variant="fp16"
).to("cuda")
# Generate with base (partial denoising)
image = base(
prompt="A majestic eagle soaring over mountains",
num_inference_steps=40,
denoising_end=0.8,
output_type="latent"
).images
# Refine with refiner
refined = refiner(
prompt="A majestic eagle soaring over mountains",
image=image,
num_inference_steps=40,
denoising_start=0.8
).images[0]
```
## T2I-Adapter
Lightweight conditioning without full ControlNet:
```python
from diffusers import StableDiffusionXLAdapterPipeline, T2IAdapter
import torch
# Load adapter
adapter = T2IAdapter.from_pretrained(
"TencentARC/t2i-adapter-canny-sdxl-1.0",
torch_dtype=torch.float16
)
pipe = StableDiffusionXLAdapterPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
adapter=adapter,
torch_dtype=torch.float16
).to("cuda")
# Get Canny edges (using controlnet_aux's CannyDetector as the preprocessor)
from controlnet_aux import CannyDetector
from diffusers.utils import load_image
input_image = load_image("input.jpg")
canny_image = CannyDetector()(input_image)
image = pipe(
prompt="A colorful anime character",
image=canny_image,
num_inference_steps=30,
adapter_conditioning_scale=0.8
).images[0]
```
## Fine-tuning with DreamBooth
Train on custom subjects:
```python
from diffusers import StableDiffusionPipeline, DDPMScheduler
from diffusers.optimization import get_scheduler
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import numpy as np
import os
class DreamBoothDataset(Dataset):
def __init__(self, instance_images_path, instance_prompt, tokenizer, size=512):
self.instance_images_path = instance_images_path
self.instance_prompt = instance_prompt
self.tokenizer = tokenizer
self.size = size
self.instance_images = [
os.path.join(instance_images_path, f)
for f in os.listdir(instance_images_path)
if f.endswith(('.png', '.jpg', '.jpeg'))
]
def __len__(self):
return len(self.instance_images)
def __getitem__(self, idx):
image = Image.open(self.instance_images[idx]).convert("RGB")
image = image.resize((self.size, self.size))
image = torch.tensor(np.array(image)).permute(2, 0, 1) / 127.5 - 1.0
tokens = self.tokenizer(
self.instance_prompt,
padding="max_length",
max_length=77,
truncation=True,
return_tensors="pt"
)
return {"image": image, "input_ids": tokens.input_ids.squeeze()}
def train_dreambooth(
pretrained_model: str,
instance_data_dir: str,
instance_prompt: str,
output_dir: str,
learning_rate: float = 5e-6,
max_train_steps: int = 800,
train_batch_size: int = 1
):
# Load pipeline
pipe = StableDiffusionPipeline.from_pretrained(pretrained_model)
unet = pipe.unet
vae = pipe.vae
text_encoder = pipe.text_encoder
tokenizer = pipe.tokenizer
noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model, subfolder="scheduler")
# Freeze VAE and text encoder
vae.requires_grad_(False)
text_encoder.requires_grad_(False)
# Create dataset
dataset = DreamBoothDataset(
instance_data_dir, instance_prompt, tokenizer
)
dataloader = DataLoader(dataset, batch_size=train_batch_size, shuffle=True)
# Setup optimizer
optimizer = torch.optim.AdamW(unet.parameters(), lr=learning_rate)
lr_scheduler = get_scheduler(
"constant",
optimizer=optimizer,
num_warmup_steps=0,
num_training_steps=max_train_steps
)
# Training loop
unet.train()
device = "cuda"
unet.to(device)
vae.to(device)
text_encoder.to(device)
global_step = 0
for epoch in range(max_train_steps // len(dataloader) + 1):
for batch in dataloader:
if global_step >= max_train_steps:
break
# Encode images to latents
latents = vae.encode(batch["image"].to(device)).latent_dist.sample()
latents = latents * vae.config.scaling_factor
# Sample noise
noise = torch.randn_like(latents)
timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (latents.shape[0],))
timesteps = timesteps.to(device)
# Add noise
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
# Get text embeddings
encoder_hidden_states = text_encoder(batch["input_ids"].to(device))[0]
# Predict noise
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
# Compute loss
loss = torch.nn.functional.mse_loss(noise_pred, noise)
# Backprop
loss.backward()
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
global_step += 1
if global_step % 100 == 0:
print(f"Step {global_step}, Loss: {loss.item():.4f}")
# Save model
pipe.unet = unet
pipe.save_pretrained(output_dir)
```
## LoRA Training
Efficient fine-tuning with Low-Rank Adaptation:
```python
from peft import LoraConfig, get_peft_model
from diffusers import StableDiffusionPipeline
import torch
def train_lora(
base_model: str,
train_dataset,
output_dir: str,
lora_rank: int = 4,
learning_rate: float = 1e-4,
max_train_steps: int = 1000
):
pipe = StableDiffusionPipeline.from_pretrained(base_model)
unet = pipe.unet
# Configure LoRA
lora_config = LoraConfig(
r=lora_rank,
lora_alpha=lora_rank,
target_modules=["to_q", "to_v", "to_k", "to_out.0"],
lora_dropout=0.1
)
# Apply LoRA to UNet
unet = get_peft_model(unet, lora_config)
unet.print_trainable_parameters() # Shows ~0.1% trainable
# Train (similar to DreamBooth but only LoRA params)
optimizer = torch.optim.AdamW(
unet.parameters(),
lr=learning_rate
)
# ... training loop ...
# Save LoRA weights only
unet.save_pretrained(output_dir)
```
## Textual Inversion
Learn new concepts through embeddings:
```python
from diffusers import StableDiffusionPipeline
import torch
# Load with textual inversion
pipe = StableDiffusionPipeline.from_pretrained(
"stable-diffusion-v1-5/stable-diffusion-v1-5",
torch_dtype=torch.float16
).to("cuda")
# Load learned embedding
pipe.load_textual_inversion(
"sd-concepts-library/cat-toy",
token="<cat-toy>"
)
# Use in prompts
image = pipe("A photo of <cat-toy> on a beach").images[0]
```
## Quantization
Reduce memory with quantization:
```python
from diffusers import BitsAndBytesConfig, StableDiffusionXLPipeline
import torch
# 8-bit quantization
quantization_config = BitsAndBytesConfig(load_in_8bit=True)
pipe = StableDiffusionXLPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
quantization_config=quantization_config,
torch_dtype=torch.float16
)
```
### NF4 quantization (4-bit)
```python
quantization_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.float16
)
pipe = StableDiffusionXLPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
quantization_config=quantization_config
)
```
## Production Deployment
### FastAPI server
```python
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from diffusers import DiffusionPipeline
import torch
import base64
from io import BytesIO
app = FastAPI()
# Load model at startup
pipe = DiffusionPipeline.from_pretrained(
"stable-diffusion-v1-5/stable-diffusion-v1-5",
torch_dtype=torch.float16
)
# Model CPU offload manages device placement itself; no .to("cuda") needed
pipe.enable_model_cpu_offload()
class GenerationRequest(BaseModel):
prompt: str
negative_prompt: str = ""
num_inference_steps: int = 30
guidance_scale: float = 7.5
width: int = 512
height: int = 512
seed: int = None
class GenerationResponse(BaseModel):
image_base64: str
seed: int
@app.post("/generate", response_model=GenerationResponse)
async def generate(request: GenerationRequest):
try:
        seed = request.seed if request.seed is not None else torch.randint(0, 2**32, (1,)).item()
        generator = torch.Generator("cuda").manual_seed(seed)
image = pipe(
prompt=request.prompt,
negative_prompt=request.negative_prompt,
num_inference_steps=request.num_inference_steps,
guidance_scale=request.guidance_scale,
width=request.width,
height=request.height,
generator=generator
).images[0]
# Convert to base64
buffer = BytesIO()
image.save(buffer, format="PNG")
image_base64 = base64.b64encode(buffer.getvalue()).decode()
return GenerationResponse(image_base64=image_base64, seed=seed)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.get("/health")
async def health():
return {"status": "healthy"}
```
### Docker deployment
```dockerfile
FROM nvidia/cuda:12.1.1-runtime-ubuntu22.04
RUN apt-get update && apt-get install -y python3 python3-pip
WORKDIR /app
COPY requirements.txt .
RUN pip3 install -r requirements.txt
COPY . .
# Pre-download model
RUN python3 -c "from diffusers import DiffusionPipeline; DiffusionPipeline.from_pretrained('stable-diffusion-v1-5/stable-diffusion-v1-5')"
EXPOSE 8000
CMD ["uvicorn", "server:app", "--host", "0.0.0.0", "--port", "8000"]
```
### Kubernetes deployment
```yaml
apiVersion: apps/v1
kind: Deployment
metadata:
name: stable-diffusion
spec:
replicas: 2
selector:
matchLabels:
app: stable-diffusion
template:
metadata:
labels:
app: stable-diffusion
spec:
containers:
- name: sd
image: your-registry/stable-diffusion:latest
ports:
- containerPort: 8000
resources:
limits:
nvidia.com/gpu: 1
memory: "16Gi"
requests:
nvidia.com/gpu: 1
memory: "8Gi"
env:
- name: TRANSFORMERS_CACHE
value: "/cache/huggingface"
volumeMounts:
- name: model-cache
mountPath: /cache
volumes:
- name: model-cache
persistentVolumeClaim:
claimName: model-cache-pvc
---
apiVersion: v1
kind: Service
metadata:
name: stable-diffusion
spec:
selector:
app: stable-diffusion
ports:
- port: 80
targetPort: 8000
type: LoadBalancer
```
## Callback System
Monitor and modify generation:
```python
from diffusers import StableDiffusionPipeline
from diffusers.callbacks import PipelineCallback
import torch
class ProgressCallback(PipelineCallback):
def __init__(self):
self.progress = []
def callback_fn(self, pipe, step_index, timestep, callback_kwargs):
self.progress.append({
"step": step_index,
"timestep": timestep.item()
})
# Optionally modify latents
latents = callback_kwargs["latents"]
return callback_kwargs
# Use callback
callback = ProgressCallback()
image = pipe(
prompt="A sunset",
callback_on_step_end=callback.callback_fn,
callback_on_step_end_tensor_inputs=["latents"]
).images[0]
print(f"Generation completed in {len(callback.progress)} steps")
```
### Early stopping
```python
def early_stop_callback(pipe, step_index, timestep, callback_kwargs):
# Stop after 20 steps
if step_index >= 20:
pipe._interrupt = True
return callback_kwargs
image = pipe(
prompt="A landscape",
num_inference_steps=50,
callback_on_step_end=early_stop_callback
).images[0]
```
## Multi-GPU Inference
### Device map auto
```python
from diffusers import StableDiffusionXLPipeline
import torch
pipe = StableDiffusionXLPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
device_map="auto", # Automatically distribute across GPUs
torch_dtype=torch.float16
)
```
### Manual distribution
```python
from accelerate import infer_auto_device_map, dispatch_model
# Create device map
device_map = infer_auto_device_map(
pipe.unet,
max_memory={0: "10GiB", 1: "10GiB"}
)
# Dispatch model
pipe.unet = dispatch_model(pipe.unet, device_map=device_map)
```

View file

@@ -0,0 +1,555 @@
# Stable Diffusion Troubleshooting Guide
## Installation Issues
### Package conflicts
**Error**: `ImportError: cannot import name 'cached_download' from 'huggingface_hub'`
**Fix**:
```bash
# Update huggingface_hub
pip install --upgrade huggingface_hub
# Reinstall diffusers
pip install --upgrade diffusers
```
### xFormers installation fails
**Error**: `RuntimeError: CUDA error: no kernel image is available for execution`
**Fix**:
```bash
# Check CUDA version
nvcc --version
# Install matching xformers
pip install xformers --index-url https://download.pytorch.org/whl/cu121 # For CUDA 12.1
# Or build from source
pip install -v -U git+https://github.com/facebookresearch/xformers.git@main#egg=xformers
```
### Torch/CUDA mismatch
**Error**: `RuntimeError: CUDA error: CUBLAS_STATUS_NOT_INITIALIZED`
**Fix**:
```bash
# Check versions
python -c "import torch; print(torch.__version__, torch.cuda.is_available())"
# Reinstall PyTorch with correct CUDA
pip uninstall torch torchvision
pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121
```
## Memory Issues
### CUDA out of memory
**Error**: `torch.cuda.OutOfMemoryError: CUDA out of memory`
**Solutions**:
```python
# Solution 1: Enable CPU offloading
pipe.enable_model_cpu_offload()
# Solution 2: Sequential CPU offload (more aggressive)
pipe.enable_sequential_cpu_offload()
# Solution 3: Attention slicing
pipe.enable_attention_slicing()
# Solution 4: VAE slicing for large images
pipe.enable_vae_slicing()
# Solution 5: Use lower precision
pipe = DiffusionPipeline.from_pretrained(
"model-id",
torch_dtype=torch.float16 # or torch.bfloat16
)
# Solution 6: Reduce batch size
image = pipe(prompt, num_images_per_prompt=1).images[0]
# Solution 7: Generate smaller images
image = pipe(prompt, height=512, width=512).images[0]
# Solution 8: Clear cache between generations
import gc
torch.cuda.empty_cache()
gc.collect()
```
### Memory grows over time
**Problem**: Memory usage increases with each generation
**Fix**:
```python
import gc
import torch
def generate_with_cleanup(pipe, prompt, **kwargs):
try:
image = pipe(prompt, **kwargs).images[0]
return image
finally:
# Clear cache after generation
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
```
### Large model loading fails
**Error**: `RuntimeError: Unable to load model weights`
**Fix**:
```python
# Use low CPU memory mode
pipe = DiffusionPipeline.from_pretrained(
"large-model-id",
low_cpu_mem_usage=True,
torch_dtype=torch.float16
)
```
## Generation Issues
### Black images
**Problem**: Output images are completely black
**Solutions**:
```python
# Solution 1: Disable safety checker
pipe.safety_checker = None
# Solution 2: Check VAE scaling
# The issue might be with VAE encoding/decoding
latents = latents / pipe.vae.config.scaling_factor # Before decode
# Solution 3: Ensure proper dtype
pipe = pipe.to(dtype=torch.float16)
pipe.vae = pipe.vae.to(dtype=torch.float32) # VAE often needs fp32
# Solution 4: Check guidance scale
# Too high can cause issues
image = pipe(prompt, guidance_scale=7.5).images[0] # Not 20+
```
### Noise/static images
**Problem**: Output looks like random noise
**Solutions**:
```python
# Solution 1: Increase inference steps
image = pipe(prompt, num_inference_steps=50).images[0]
# Solution 2: Check scheduler configuration
pipe.scheduler = pipe.scheduler.from_config(pipe.scheduler.config)
# Solution 3: Verify model was loaded correctly
print(pipe.unet) # Should show model architecture
```
### Blurry images
**Problem**: Output images are low quality or blurry
**Solutions**:
```python
# Solution 1: Use more steps
image = pipe(prompt, num_inference_steps=50).images[0]
# Solution 2: Use better VAE
from diffusers import AutoencoderKL
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
pipe.vae = vae
# Solution 3: Use SDXL or refiner
pipe = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0"
)
# Solution 4: Upscale with img2img
upscale_pipe = StableDiffusionImg2ImgPipeline.from_pretrained(...)
upscaled = upscale_pipe(
prompt=prompt,
image=image.resize((1024, 1024)),
strength=0.3
).images[0]
```
### Prompt not being followed
**Problem**: Generated image doesn't match the prompt
**Solutions**:
```python
# Solution 1: Increase guidance scale
image = pipe(prompt, guidance_scale=10.0).images[0]
# Solution 2: Use negative prompts
image = pipe(
prompt="A red car",
negative_prompt="blue, green, yellow, wrong color",
guidance_scale=7.5
).images[0]
# Solution 3: Use prompt weighting to emphasize important words
# Note: the (word:weight) syntax is not parsed by diffusers itself;
# use a helper library such as compel to turn it into weighted embeddings
prompt = "A (red:1.5) car on a street"
# Solution 4: Use longer, more detailed prompts
prompt = """
A bright red sports car, ferrari style, parked on a city street,
photorealistic, high detail, 8k, professional photography
"""
```
### Distorted faces/hands
**Problem**: Faces and hands look deformed
**Solutions**:
```python
# Solution 1: Use negative prompts
negative_prompt = """
bad hands, bad anatomy, deformed, ugly, blurry,
extra fingers, mutated hands, poorly drawn hands,
poorly drawn face, mutation, deformed face
"""
# Solution 2: Use face-specific models
# ADetailer or similar post-processing
# Solution 3: Use ControlNet for poses
# Load pose estimation and condition generation
# Solution 4: Inpaint problematic areas
# create_face_mask is a user-supplied helper (e.g., face detection producing a white mask over the face)
mask = create_face_mask(image)
fixed = inpaint_pipe(
prompt="beautiful detailed face",
image=image,
mask_image=mask
).images[0]
```
## Scheduler Issues
### Scheduler not compatible
**Error**: `ValueError: Scheduler ... is not compatible with pipeline`
**Fix**:
```python
from diffusers import EulerDiscreteScheduler
# Create scheduler from config
pipe.scheduler = EulerDiscreteScheduler.from_config(
pipe.scheduler.config
)
# Check compatible schedulers
print(pipe.scheduler.compatibles)
```
### Wrong number of steps
**Problem**: Model generates different quality with same steps
**Fix**:
```python
# Reset timesteps explicitly
pipe.scheduler.set_timesteps(num_inference_steps)
# Check scheduler's step count
print(len(pipe.scheduler.timesteps))
```
## LoRA Issues
### LoRA weights not loading
**Error**: `RuntimeError: Error(s) in loading state_dict for UNet2DConditionModel`
**Fix**:
```python
# Check weight file format
# Should be .safetensors or .bin
# Load with correct key prefix
pipe.load_lora_weights(
"path/to/lora",
weight_name="lora.safetensors"
)
# Try loading into specific component
pipe.unet.load_attn_procs("path/to/lora")
```
### LoRA not affecting output
**Problem**: Generated images look the same with/without LoRA
**Fix**:
```python
# Fuse LoRA weights
pipe.fuse_lora(lora_scale=1.0)
# Or set scale explicitly
pipe.set_adapters(["lora_name"], adapter_weights=[1.0])
# Verify LoRA is loaded
print(list(pipe.unet.attn_processors.keys()))
```
### Multiple LoRAs conflict
**Problem**: Multiple LoRAs produce artifacts
**Fix**:
```python
# Load with different adapter names
pipe.load_lora_weights("lora1", adapter_name="style")
pipe.load_lora_weights("lora2", adapter_name="subject")
# Balance weights
pipe.set_adapters(
["style", "subject"],
adapter_weights=[0.5, 0.5] # Lower weights
)
# Or use LoRA merge before loading
# Merge LoRAs offline with appropriate ratios
```
## ControlNet Issues
### ControlNet not conditioning
**Problem**: ControlNet has no effect on output
**Fix**:
```python
# Check control image format
# Should be RGB, matching generation size
control_image = control_image.resize((512, 512))
# Increase conditioning scale
image = pipe(
prompt=prompt,
image=control_image,
controlnet_conditioning_scale=1.0, # Try 0.5-1.5
num_inference_steps=30
).images[0]
# Verify ControlNet is loaded
print(pipe.controlnet)
```
### Control image preprocessing
**Fix**:
```python
from controlnet_aux import CannyDetector
# Proper preprocessing
canny = CannyDetector()
control_image = canny(input_image)
# Ensure correct format
control_image = control_image.convert("RGB")
control_image = control_image.resize((512, 512))
```
## Hub/Download Issues
### Model download fails
**Error**: `requests.exceptions.ConnectionError`
**Fix**:
```bash
# Set longer timeout
export HF_HUB_DOWNLOAD_TIMEOUT=600
# Use mirror if available
export HF_ENDPOINT=https://hf-mirror.com
# Or download manually
huggingface-cli download stable-diffusion-v1-5/stable-diffusion-v1-5
```
### Cache issues
**Error**: `OSError: Can't load model from cache`
**Fix**:
```bash
# Clear cache
rm -rf ~/.cache/huggingface/hub
# Or set different cache location
export HF_HOME=/path/to/cache
# Force re-download
pipe = DiffusionPipeline.from_pretrained(
"model-id",
force_download=True
)
```
### Access denied for gated models
**Error**: `401 Client Error: Unauthorized`
**Fix**:
```bash
# Login to Hugging Face
huggingface-cli login
# Or use token
pipe = DiffusionPipeline.from_pretrained(
"model-id",
token="hf_xxxxx"
)
# Accept model license on Hub website first
```
## Performance Issues
### Slow generation
**Problem**: Generation takes too long
**Solutions**:
```python
# Solution 1: Use faster scheduler
from diffusers import DPMSolverMultistepScheduler
pipe.scheduler = DPMSolverMultistepScheduler.from_config(
pipe.scheduler.config
)
# Solution 2: Reduce steps
image = pipe(prompt, num_inference_steps=20).images[0]
# Solution 3: Use LCM
from diffusers import LCMScheduler
pipe.load_lora_weights("latent-consistency/lcm-lora-sdxl")
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
image = pipe(prompt, num_inference_steps=4, guidance_scale=1.0).images[0]
# Solution 4: Enable xFormers
pipe.enable_xformers_memory_efficient_attention()
# Solution 5: Compile model
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
```
### First generation is slow
**Problem**: First image takes much longer
**Fix**:
```python
# Warm up the model
_ = pipe("warmup", num_inference_steps=1)
# Then run actual generation
image = pipe(prompt, num_inference_steps=50).images[0]
# Compile for faster subsequent runs
pipe.unet = torch.compile(pipe.unet)
```
## Debugging Tips
### Enable debug logging
```python
import logging
logging.basicConfig(level=logging.DEBUG)
# Or for specific modules
logging.getLogger("diffusers").setLevel(logging.DEBUG)
logging.getLogger("transformers").setLevel(logging.DEBUG)
```
### Check model components
```python
# Print pipeline components
print(pipe.components)
# Check model config
print(pipe.unet.config)
print(pipe.vae.config)
print(pipe.scheduler.config)
# Verify device placement
print(pipe.device)
for name, module in pipe.components.items():
if hasattr(module, 'device'):
print(f"{name}: {module.device}")
```
### Validate inputs
```python
# Check image dimensions
print(f"Height: {height}, Width: {width}")
assert height % 8 == 0, "Height must be divisible by 8"
assert width % 8 == 0, "Width must be divisible by 8"
# Check prompt tokenization
tokens = pipe.tokenizer(prompt, return_tensors="pt")
print(f"Token count: {tokens.input_ids.shape[1]}") # Max 77 for SD
```
### Save intermediate results
```python
def save_latents_callback(pipe, step_index, timestep, callback_kwargs):
latents = callback_kwargs["latents"]
# Decode and save intermediate
with torch.no_grad():
image = pipe.vae.decode(latents / pipe.vae.config.scaling_factor).sample
image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy()[0]
Image.fromarray((image * 255).astype("uint8")).save(f"step_{step_index}.png")
return callback_kwargs
image = pipe(
prompt,
callback_on_step_end=save_latents_callback,
callback_on_step_end_tensor_inputs=["latents"]
).images[0]
```
## Getting Help
1. **Documentation**: https://huggingface.co/docs/diffusers
2. **GitHub Issues**: https://github.com/huggingface/diffusers/issues
3. **Discord**: https://discord.gg/diffusers
4. **Forum**: https://discuss.huggingface.co
### Reporting Issues
Include:
- Diffusers version: `pip show diffusers`
- PyTorch version: `python -c "import torch; print(torch.__version__)"`
- CUDA version: `nvcc --version`
- GPU model: `nvidia-smi`
- Full error traceback
- Minimal reproducible code
- Model name/ID used
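Most of this can be collected in one step with the CLI that ships with diffusers (assuming a recent version):
```bash
diffusers-cli env
```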

View file

@@ -0,0 +1,320 @@
---
name: whisper
description: OpenAI's general-purpose speech recognition model. Supports 99 languages, transcription, translation to English, and language identification. Six model sizes from tiny (39M params) to large (1550M params). Use for speech-to-text, podcast transcription, or multilingual audio processing. Best for robust, multilingual ASR.
version: 1.0.0
author: Orchestra Research
license: MIT
dependencies: [openai-whisper, transformers, torch]
metadata:
hermes:
tags: [Whisper, Speech Recognition, ASR, Multimodal, Multilingual, OpenAI, Speech-To-Text, Transcription, Translation, Audio Processing]
---
# Whisper - Robust Speech Recognition
OpenAI's multilingual speech recognition model.
## When to use Whisper
**Use when:**
- Speech-to-text transcription (99 languages)
- Podcast/video transcription
- Meeting notes automation
- Translation to English
- Noisy audio transcription
- Multilingual audio processing
**Metrics**:
- **72,900+ GitHub stars**
- 99 languages supported
- Trained on 680,000 hours of audio
- MIT License
**Use alternatives instead**:
- **AssemblyAI**: Managed API, speaker diarization
- **Deepgram**: Real-time streaming ASR
- **Google Speech-to-Text**: Cloud-based
## Quick start
### Installation
```bash
# Requires Python 3.8-3.11
pip install -U openai-whisper
# Requires ffmpeg
# macOS: brew install ffmpeg
# Ubuntu: sudo apt install ffmpeg
# Windows: choco install ffmpeg
```
### Basic transcription
```python
import whisper
# Load model
model = whisper.load_model("base")
# Transcribe
result = model.transcribe("audio.mp3")
# Print text
print(result["text"])
# Access segments
for segment in result["segments"]:
print(f"[{segment['start']:.2f}s - {segment['end']:.2f}s] {segment['text']}")
```
## Model sizes
```python
# Available models
models = ["tiny", "base", "small", "medium", "large", "turbo"]
# Load specific model
model = whisper.load_model("turbo") # Fastest, good quality
```
| Model | Parameters | English-only | Multilingual | Speed | VRAM |
|-------|------------|--------------|--------------|-------|------|
| tiny | 39M | ✓ | ✓ | ~32x | ~1 GB |
| base | 74M | ✓ | ✓ | ~16x | ~1 GB |
| small | 244M | ✓ | ✓ | ~6x | ~2 GB |
| medium | 769M | ✓ | ✓ | ~2x | ~5 GB |
| large | 1550M | ✗ | ✓ | 1x | ~10 GB |
| turbo | 809M | ✗ | ✓ | ~8x | ~6 GB |
**Recommendation**: Use `turbo` for best speed/quality, `base` for prototyping
## Transcription options
### Language specification
```python
# Auto-detect language
result = model.transcribe("audio.mp3")
# Specify language (faster)
result = model.transcribe("audio.mp3", language="en")
# Supported: en, es, fr, de, it, pt, ru, ja, ko, zh, and 89 more
```
### Task selection
```python
# Transcription (default)
result = model.transcribe("audio.mp3", task="transcribe")
# Translation to English
result = model.transcribe("spanish.mp3", task="translate")
# Input: Spanish audio → Output: English text
```
### Initial prompt
```python
# Improve accuracy with context
result = model.transcribe(
"audio.mp3",
initial_prompt="This is a technical podcast about machine learning and AI."
)
# Helps with:
# - Technical terms
# - Proper nouns
# - Domain-specific vocabulary
```
### Timestamps
```python
# Word-level timestamps
result = model.transcribe("audio.mp3", word_timestamps=True)
for segment in result["segments"]:
for word in segment["words"]:
print(f"{word['word']} ({word['start']:.2f}s - {word['end']:.2f}s)")
```
### Temperature fallback
```python
# Retry with different temperatures if confidence low
result = model.transcribe(
"audio.mp3",
temperature=(0.0, 0.2, 0.4, 0.6, 0.8, 1.0)
)
```
## Command line usage
```bash
# Basic transcription
whisper audio.mp3
# Specify model
whisper audio.mp3 --model turbo
# Output formats
whisper audio.mp3 --output_format txt # Plain text
whisper audio.mp3 --output_format srt # Subtitles
whisper audio.mp3 --output_format vtt # WebVTT
whisper audio.mp3 --output_format json # JSON with timestamps
# Language
whisper audio.mp3 --language Spanish
# Translation
whisper spanish.mp3 --task translate
```
## Batch processing
```python
import os
audio_files = ["file1.mp3", "file2.mp3", "file3.mp3"]
for audio_file in audio_files:
print(f"Transcribing {audio_file}...")
result = model.transcribe(audio_file)
# Save to file
output_file = audio_file.replace(".mp3", ".txt")
with open(output_file, "w") as f:
f.write(result["text"])
```
## Real-time transcription
```python
# For streaming audio, use faster-whisper
# pip install faster-whisper
from faster_whisper import WhisperModel
model = WhisperModel("base", device="cuda", compute_type="float16")
# Transcribe with streaming
segments, info = model.transcribe("audio.mp3", beam_size=5)
for segment in segments:
print(f"[{segment.start:.2f}s -> {segment.end:.2f}s] {segment.text}")
```
## GPU acceleration
```python
import whisper
# Automatically uses GPU if available
model = whisper.load_model("turbo")
# Force CPU
model = whisper.load_model("turbo", device="cpu")
# Force GPU
model = whisper.load_model("turbo", device="cuda")
# 10-20× faster on GPU
```
## Integration with other tools
### Subtitle generation
```bash
# Generate SRT subtitles
whisper video.mp4 --output_format srt --language English
# Output: video.srt
```
### With LangChain
```python
from langchain.document_loaders import WhisperTranscriptionLoader
loader = WhisperTranscriptionLoader(file_path="audio.mp3")
docs = loader.load()
# Use transcription in RAG
from langchain_chroma import Chroma
from langchain_openai import OpenAIEmbeddings
vectorstore = Chroma.from_documents(docs, OpenAIEmbeddings())
```
### Extract audio from video
```bash
# Use ffmpeg to extract audio
ffmpeg -i video.mp4 -vn -acodec pcm_s16le audio.wav
# Then transcribe
whisper audio.wav
```
## Best practices
1. **Use turbo model** - Best speed/quality for English
2. **Specify language** - Faster than auto-detect
3. **Add initial prompt** - Improves technical terms
4. **Use GPU** - 10-20× faster
5. **Batch process** - More efficient
6. **Convert to WAV** - Better compatibility
7. **Split long audio** - <30 min chunks
8. **Check language support** - Quality varies by language
9. **Use faster-whisper** - 4× faster than openai-whisper
10. **Monitor VRAM** - Scale model size to hardware
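A sketch combining several of these practices (turbo model on GPU, explicit language, and a domain prompt); the file name is illustrative:
```python
import whisper
# Load the turbo model on GPU for the best speed/quality trade-off
model = whisper.load_model("turbo", device="cuda")
result = model.transcribe(
    "podcast_episode.wav",   # WAV converted beforehand with ffmpeg for compatibility
    language="en",           # skip auto-detection
    initial_prompt="A podcast about machine learning infrastructure and GPUs."
)
print(result["text"])
```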
## Performance
| Model | Real-time factor (CPU) | Real-time factor (GPU) |
|-------|------------------------|------------------------|
| tiny | ~0.32 | ~0.01 |
| base | ~0.16 | ~0.01 |
| turbo | ~0.08 | ~0.01 |
| large | ~1.0 | ~0.05 |
*Real-time factor: 0.1 = 10× faster than real-time*
## Language support
Top-supported languages:
- English (en)
- Spanish (es)
- French (fr)
- German (de)
- Italian (it)
- Portuguese (pt)
- Russian (ru)
- Japanese (ja)
- Korean (ko)
- Chinese (zh)
Full list: 99 languages total
## Limitations
1. **Hallucinations** - May repeat or invent text
2. **Long-form accuracy** - Degrades on >30 min audio
3. **Speaker identification** - No diarization
4. **Accents** - Quality varies
5. **Background noise** - Can affect accuracy
6. **Real-time latency** - Not suitable for live captioning
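Hallucination and repetition on long recordings can often be reduced with the decoding options exposed by `transcribe()`; a sketch using openai-whisper's built-in thresholds:
```python
import whisper
model = whisper.load_model("turbo")
result = model.transcribe(
    "long_interview.mp3",
    condition_on_previous_text=False,  # reduces repetition loops on long audio
    no_speech_threshold=0.6,           # skip segments that are likely silence
    logprob_threshold=-1.0,            # retry at higher temperature on low confidence
    compression_ratio_threshold=2.4    # flag highly repetitive (likely hallucinated) output
)
print(result["text"])
```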
## Resources
- **GitHub**: https://github.com/openai/whisper ⭐ 72,900+
- **Paper**: https://arxiv.org/abs/2212.04356
- **Model Card**: https://github.com/openai/whisper/blob/main/model-card.md
- **Colab**: Available in repo
- **License**: MIT

View file

@@ -0,0 +1,189 @@
# Whisper Language Support Guide
Complete guide to Whisper's multilingual capabilities.
## Supported languages (99 total)
### Top-tier support (WER < 10%)
- English (en)
- Spanish (es)
- French (fr)
- German (de)
- Italian (it)
- Portuguese (pt)
- Dutch (nl)
- Polish (pl)
- Russian (ru)
- Japanese (ja)
- Korean (ko)
- Chinese (zh)
### Good support (WER 10-20%)
- Arabic (ar)
- Turkish (tr)
- Vietnamese (vi)
- Swedish (sv)
- Finnish (fi)
- Czech (cs)
- Romanian (ro)
- Hungarian (hu)
- Danish (da)
- Norwegian (no)
- Thai (th)
- Hebrew (he)
- Greek (el)
- Indonesian (id)
- Malay (ms)
### Full list (99 languages)
Afrikaans, Albanian, Amharic, Arabic, Armenian, Assamese, Azerbaijani, Bashkir, Basque, Belarusian, Bengali, Bosnian, Breton, Bulgarian, Burmese, Cantonese, Catalan, Chinese, Croatian, Czech, Danish, Dutch, English, Estonian, Faroese, Finnish, French, Galician, Georgian, German, Greek, Gujarati, Haitian Creole, Hausa, Hawaiian, Hebrew, Hindi, Hungarian, Icelandic, Indonesian, Italian, Japanese, Javanese, Kannada, Kazakh, Khmer, Korean, Lao, Latin, Latvian, Lingala, Lithuanian, Luxembourgish, Macedonian, Malagasy, Malay, Malayalam, Maltese, Maori, Marathi, Moldavian, Mongolian, Myanmar, Nepali, Norwegian, Nynorsk, Occitan, Pashto, Persian, Polish, Portuguese, Punjabi, Pushto, Romanian, Russian, Sanskrit, Serbian, Shona, Sindhi, Sinhala, Slovak, Slovenian, Somali, Spanish, Sundanese, Swahili, Swedish, Tagalog, Tajik, Tamil, Tatar, Telugu, Thai, Tibetan, Turkish, Turkmen, Ukrainian, Urdu, Uzbek, Vietnamese, Welsh, Yiddish, Yoruba
## Usage examples
### Auto-detect language
```python
import whisper
model = whisper.load_model("turbo")
# Auto-detect language
result = model.transcribe("audio.mp3")
print(f"Detected language: {result['language']}")
print(f"Text: {result['text']}")
```
### Specify language (faster)
```python
# Specify language for faster transcription
result = model.transcribe("audio.mp3", language="es") # Spanish
result = model.transcribe("audio.mp3", language="fr") # French
result = model.transcribe("audio.mp3", language="ja") # Japanese
```
### Translation to English
```python
# Translate any language to English
result = model.transcribe(
"spanish_audio.mp3",
task="translate" # Translates to English
)
print(f"Original language: {result['language']}")
print(f"English translation: {result['text']}")
```
## Language-specific tips
### Chinese
```python
# Chinese works well with larger models
model = whisper.load_model("large")
result = model.transcribe(
"chinese_audio.mp3",
language="zh",
initial_prompt="这是一段关于技术的讨论" # Context helps
)
```
### Japanese
```python
# Japanese benefits from initial prompt
result = model.transcribe(
"japanese_audio.mp3",
language="ja",
initial_prompt="これは技術的な会議の録音です"
)
```
### Arabic
```python
# Arabic: Use large model for best results
model = whisper.load_model("large")
result = model.transcribe(
"arabic_audio.mp3",
language="ar"
)
```
## Model size recommendations
| Language Tier | Recommended Model | WER |
|---------------|-------------------|-----|
| Top-tier (en, es, fr, de) | base/turbo | < 10% |
| Good (ar, tr, vi) | medium/large | 10-20% |
| Lower-resource | large | 20-30% |
## Performance by language
### English
- **tiny**: WER ~15%
- **base**: WER ~8%
- **small**: WER ~5%
- **medium**: WER ~4%
- **large**: WER ~3%
- **turbo**: WER ~3.5%
### Spanish
- **tiny**: WER ~20%
- **base**: WER ~12%
- **medium**: WER ~6%
- **large**: WER ~4%
### Chinese
- **small**: WER ~15%
- **medium**: WER ~8%
- **large**: WER ~5%
## Best practices
1. **Use English-only variants** - `tiny.en`/`base.en` are more accurate for English audio with small models
2. **Specify language** - Faster than auto-detect
3. **Add initial prompt** - Improves accuracy for technical terms
4. **Use larger models** - For low-resource languages
5. **Test on sample** - Quality varies by accent/dialect
6. **Consider audio quality** - Clear audio = better results
7. **Check language codes** - Use ISO 639-1 codes (2 letters)
## Language detection
```python
# Detect language only (no transcription)
import whisper
model = whisper.load_model("base")
# Load audio
audio = whisper.load_audio("audio.mp3")
audio = whisper.pad_or_trim(audio)
# Make log-Mel spectrogram
mel = whisper.log_mel_spectrogram(audio).to(model.device)
# Detect language
_, probs = model.detect_language(mel)
detected_language = max(probs, key=probs.get)
print(f"Detected language: {detected_language}")
print(f"Confidence: {probs[detected_language]:.2%}")
```
## Resources
- **Paper**: https://arxiv.org/abs/2212.04356
- **GitHub**: https://github.com/openai/whisper
- **Model Card**: https://github.com/openai/whisper/blob/main/model-card.md

View file

@@ -0,0 +1,3 @@
---
description: ML research frameworks for building and optimizing AI systems with declarative programming.
---

View file

@@ -0,0 +1,593 @@
---
name: dspy
description: Build complex AI systems with declarative programming, optimize prompts automatically, create modular RAG systems and agents with DSPy - Stanford NLP's framework for systematic LM programming
version: 1.0.0
author: Orchestra Research
license: MIT
dependencies: [dspy, openai, anthropic]
metadata:
hermes:
tags: [Prompt Engineering, DSPy, Declarative Programming, RAG, Agents, Prompt Optimization, LM Programming, Stanford NLP, Automatic Optimization, Modular AI]
---
# DSPy: Declarative Language Model Programming
## When to Use This Skill
Use DSPy when you need to:
- **Build complex AI systems** with multiple components and workflows
- **Program LMs declaratively** instead of manual prompt engineering
- **Optimize prompts automatically** using data-driven methods
- **Create modular AI pipelines** that are maintainable and portable
- **Improve model outputs systematically** with optimizers
- **Build RAG systems, agents, or classifiers** with better reliability
**GitHub Stars**: 22,000+ | **Created By**: Stanford NLP
## Installation
```bash
# Stable release
pip install dspy
# Latest development version
pip install git+https://github.com/stanfordnlp/dspy.git
# With specific LM providers
pip install dspy[openai] # OpenAI
pip install dspy[anthropic] # Anthropic Claude
pip install dspy[all] # All providers
```
## Quick Start
### Basic Example: Question Answering
```python
import dspy
# Configure your language model
lm = dspy.Claude(model="claude-sonnet-4-5-20250929")
dspy.settings.configure(lm=lm)
# Define a signature (input → output)
class QA(dspy.Signature):
"""Answer questions with short factual answers."""
question = dspy.InputField()
answer = dspy.OutputField(desc="often between 1 and 5 words")
# Create a module
qa = dspy.Predict(QA)
# Use it
response = qa(question="What is the capital of France?")
print(response.answer) # "Paris"
```
### Chain of Thought Reasoning
```python
import dspy
lm = dspy.Claude(model="claude-sonnet-4-5-20250929")
dspy.settings.configure(lm=lm)
# Use ChainOfThought for better reasoning
class MathProblem(dspy.Signature):
"""Solve math word problems."""
problem = dspy.InputField()
answer = dspy.OutputField(desc="numerical answer")
# ChainOfThought generates reasoning steps automatically
cot = dspy.ChainOfThought(MathProblem)
response = cot(problem="If John has 5 apples and gives 2 to Mary, how many does he have?")
print(response.rationale) # Shows reasoning steps
print(response.answer) # "3"
```
## Core Concepts
### 1. Signatures
Signatures define the structure of your AI task (inputs → outputs):
```python
# Inline signature (simple)
qa = dspy.Predict("question -> answer")
# Class signature (detailed)
class Summarize(dspy.Signature):
"""Summarize text into key points."""
text = dspy.InputField()
summary = dspy.OutputField(desc="bullet points, 3-5 items")
summarizer = dspy.ChainOfThought(Summarize)
```
**When to use each:**
- **Inline**: Quick prototyping, simple tasks
- **Class**: Complex tasks, type hints, better documentation
### 2. Modules
Modules are reusable components that transform inputs to outputs:
#### dspy.Predict
Basic prediction module:
```python
predictor = dspy.Predict("context, question -> answer")
result = predictor(context="Paris is the capital of France",
question="What is the capital?")
```
#### dspy.ChainOfThought
Generates reasoning steps before answering:
```python
cot = dspy.ChainOfThought("question -> answer")
result = cot(question="Why is the sky blue?")
print(result.rationale) # Reasoning steps
print(result.answer) # Final answer
```
#### dspy.ReAct
Agent-like reasoning with tools:
```python
from dspy.predict import ReAct
class SearchQA(dspy.Signature):
"""Answer questions using search."""
question = dspy.InputField()
answer = dspy.OutputField()
def search_tool(query: str) -> str:
"""Search Wikipedia."""
# Your search implementation
return results
react = ReAct(SearchQA, tools=[search_tool])
result = react(question="When was Python created?")
```
#### dspy.ProgramOfThought
Generates and executes code for reasoning:
```python
pot = dspy.ProgramOfThought("question -> answer")
result = pot(question="What is 15% of 240?")
# Generates: answer = 240 * 0.15
```
### 3. Optimizers
Optimizers improve your modules automatically using training data:
#### BootstrapFewShot
Learns from examples:
```python
from dspy.teleprompt import BootstrapFewShot
# Training data
trainset = [
dspy.Example(question="What is 2+2?", answer="4").with_inputs("question"),
dspy.Example(question="What is 3+5?", answer="8").with_inputs("question"),
]
# Define metric
def validate_answer(example, pred, trace=None):
return example.answer == pred.answer
# Optimize
optimizer = BootstrapFewShot(metric=validate_answer, max_bootstrapped_demos=3)
optimized_qa = optimizer.compile(qa, trainset=trainset)
# Now optimized_qa performs better!
```
#### MIPRO (Multi-prompt Instruction Proposal Optimizer)
Iteratively improves prompts:
```python
from dspy.teleprompt import MIPRO
optimizer = MIPRO(
metric=validate_answer,
num_candidates=10,
init_temperature=1.0
)
optimized_cot = optimizer.compile(
cot,
trainset=trainset,
num_trials=100
)
```
#### BootstrapFinetune
Creates datasets for model fine-tuning:
```python
from dspy.teleprompt import BootstrapFinetune
optimizer = BootstrapFinetune(metric=validate_answer)
optimized_module = optimizer.compile(qa, trainset=trainset)
# Exports training data for fine-tuning
```
### 4. Building Complex Systems
#### Multi-Stage Pipeline
```python
import dspy
class MultiHopQA(dspy.Module):
def __init__(self):
super().__init__()
self.retrieve = dspy.Retrieve(k=3)
self.generate_query = dspy.ChainOfThought("question -> search_query")
self.generate_answer = dspy.ChainOfThought("context, question -> answer")
def forward(self, question):
# Stage 1: Generate search query
search_query = self.generate_query(question=question).search_query
# Stage 2: Retrieve context
passages = self.retrieve(search_query).passages
context = "\n".join(passages)
# Stage 3: Generate answer
answer = self.generate_answer(context=context, question=question).answer
return dspy.Prediction(answer=answer, context=context)
# Use the pipeline
qa_system = MultiHopQA()
result = qa_system(question="Who wrote the book that inspired the movie Blade Runner?")
```
#### RAG System with Optimization
```python
import dspy
from dspy.retrieve.chromadb_rm import ChromadbRM
# Configure retriever
retriever = ChromadbRM(
collection_name="documents",
persist_directory="./chroma_db"
)
class RAG(dspy.Module):
def __init__(self, num_passages=3):
super().__init__()
self.retrieve = dspy.Retrieve(k=num_passages)
self.generate = dspy.ChainOfThought("context, question -> answer")
def forward(self, question):
context = self.retrieve(question).passages
return self.generate(context=context, question=question)
# Create and optimize
rag = RAG()
# Optimize with training data
from dspy.teleprompt import BootstrapFewShot
optimizer = BootstrapFewShot(metric=validate_answer)
optimized_rag = optimizer.compile(rag, trainset=trainset)
```
## LM Provider Configuration
### Anthropic Claude
```python
import dspy
lm = dspy.Claude(
model="claude-sonnet-4-5-20250929",
api_key="your-api-key", # Or set ANTHROPIC_API_KEY env var
max_tokens=1000,
temperature=0.7
)
dspy.settings.configure(lm=lm)
```
### OpenAI
```python
lm = dspy.OpenAI(
model="gpt-4",
api_key="your-api-key",
max_tokens=1000
)
dspy.settings.configure(lm=lm)
```
### Local Models (Ollama)
```python
lm = dspy.OllamaLocal(
model="llama3.1",
base_url="http://localhost:11434"
)
dspy.settings.configure(lm=lm)
```
### Multiple Models
```python
# Different models for different tasks
cheap_lm = dspy.OpenAI(model="gpt-3.5-turbo")
strong_lm = dspy.Claude(model="claude-sonnet-4-5-20250929")
# Use cheap model for retrieval, strong model for reasoning
with dspy.settings.context(lm=cheap_lm):
context = retriever(question)
with dspy.settings.context(lm=strong_lm):
answer = generator(context=context, question=question)
```
## Common Patterns
### Pattern 1: Structured Output
```python
from pydantic import BaseModel, Field
class PersonInfo(BaseModel):
name: str = Field(description="Full name")
age: int = Field(description="Age in years")
occupation: str = Field(description="Current job")
class ExtractPerson(dspy.Signature):
"""Extract person information from text."""
text = dspy.InputField()
person: PersonInfo = dspy.OutputField()
extractor = dspy.TypedPredictor(ExtractPerson)
result = extractor(text="John Doe is a 35-year-old software engineer.")
print(result.person.name) # "John Doe"
print(result.person.age) # 35
```
### Pattern 2: Assertion-Driven Optimization
```python
import dspy
from dspy.primitives.assertions import assert_transform_module, backtrack_handler
class MathQA(dspy.Module):
def __init__(self):
super().__init__()
self.solve = dspy.ChainOfThought("problem -> solution: float")
def forward(self, problem):
solution = self.solve(problem=problem).solution
# Assert solution is numeric
dspy.Assert(
isinstance(float(solution), float),
"Solution must be a number",
backtrack=backtrack_handler
)
return dspy.Prediction(solution=solution)
```
### Pattern 3: Self-Consistency
```python
import dspy
from collections import Counter
class ConsistentQA(dspy.Module):
def __init__(self, num_samples=5):
super().__init__()
self.qa = dspy.ChainOfThought("question -> answer")
self.num_samples = num_samples
def forward(self, question):
# Generate multiple answers
answers = []
for _ in range(self.num_samples):
result = self.qa(question=question)
answers.append(result.answer)
# Return most common answer
most_common = Counter(answers).most_common(1)[0][0]
return dspy.Prediction(answer=most_common)
```
### Pattern 4: Retrieval with Reranking
```python
class RerankedRAG(dspy.Module):
def __init__(self):
super().__init__()
self.retrieve = dspy.Retrieve(k=10)
self.rerank = dspy.Predict("question, passage -> relevance_score: float")
self.answer = dspy.ChainOfThought("context, question -> answer")
def forward(self, question):
# Retrieve candidates
passages = self.retrieve(question).passages
# Rerank passages
scored = []
for passage in passages:
score = float(self.rerank(question=question, passage=passage).relevance_score)
scored.append((score, passage))
# Take top 3
top_passages = [p for _, p in sorted(scored, reverse=True)[:3]]
context = "\n\n".join(top_passages)
# Generate answer
return self.answer(context=context, question=question)
```
## Evaluation and Metrics
### Custom Metrics
```python
def exact_match(example, pred, trace=None):
"""Exact match metric."""
return example.answer.lower() == pred.answer.lower()
def f1_score(example, pred, trace=None):
"""F1 score for text overlap."""
pred_tokens = set(pred.answer.lower().split())
gold_tokens = set(example.answer.lower().split())
if not pred_tokens:
return 0.0
precision = len(pred_tokens & gold_tokens) / len(pred_tokens)
recall = len(pred_tokens & gold_tokens) / len(gold_tokens)
if precision + recall == 0:
return 0.0
return 2 * (precision * recall) / (precision + recall)
```
### Evaluation
```python
from dspy.evaluate import Evaluate
# Create evaluator
evaluator = Evaluate(
devset=testset,
metric=exact_match,
num_threads=4,
display_progress=True
)
# Evaluate model
score = evaluator(qa_system)
print(f"Accuracy: {score}")
# Compare optimized vs unoptimized
score_before = evaluator(qa)
score_after = evaluator(optimized_qa)
print(f"Improvement: {score_after - score_before:.2%}")
```
## Best Practices
### 1. Start Simple, Iterate
```python
# Start with Predict
qa = dspy.Predict("question -> answer")
# Add reasoning if needed
qa = dspy.ChainOfThought("question -> answer")
# Add optimization when you have data
optimized_qa = optimizer.compile(qa, trainset=data)
```
### 2. Use Descriptive Signatures
```python
# ❌ Bad: Vague
class Task(dspy.Signature):
input = dspy.InputField()
output = dspy.OutputField()
# ✅ Good: Descriptive
class SummarizeArticle(dspy.Signature):
"""Summarize news articles into 3-5 key points."""
article = dspy.InputField(desc="full article text")
summary = dspy.OutputField(desc="bullet points, 3-5 items")
```
### 3. Optimize with Representative Data
```python
# Create diverse training examples
trainset = [
dspy.Example(question="factual", answer="...).with_inputs("question"),
dspy.Example(question="reasoning", answer="...").with_inputs("question"),
dspy.Example(question="calculation", answer="...").with_inputs("question"),
]
# Use validation set for metric
def metric(example, pred, trace=None):
return example.answer in pred.answer
```
### 4. Save and Load Optimized Models
```python
# Save
optimized_qa.save("models/qa_v1.json")
# Load
loaded_qa = dspy.ChainOfThought("question -> answer")
loaded_qa.load("models/qa_v1.json")
```
### 5. Monitor and Debug
```python
# Enable tracing
dspy.settings.configure(lm=lm, trace=[])
# Run prediction
result = qa(question="...")
# Inspect trace
for call in dspy.settings.trace:
print(f"Prompt: {call['prompt']}")
print(f"Response: {call['response']}")
```
## Comparison to Other Approaches
| Feature | Manual Prompting | LangChain | DSPy |
|---------|-----------------|-----------|------|
| Prompt Engineering | Manual | Manual | Automatic |
| Optimization | Trial & error | None | Data-driven |
| Modularity | Low | Medium | High |
| Type Safety | No | Limited | Yes (Signatures) |
| Portability | Low | Medium | High |
| Learning Curve | Low | Medium | Medium-High |
**When to choose DSPy:**
- You have training data or can generate it
- You need systematic prompt improvement
- You're building complex multi-stage systems
- You want to optimize across different LMs
**When to choose alternatives:**
- Quick prototypes (manual prompting)
- Simple chains with existing tools (LangChain)
- Custom optimization logic needed
## Resources
- **Documentation**: https://dspy.ai
- **GitHub**: https://github.com/stanfordnlp/dspy (22k+ stars)
- **Discord**: https://discord.gg/XCGy2WDCQB
- **Twitter**: @DSPyOSS
- **Paper**: "DSPy: Compiling Declarative Language Model Calls into Self-Improving Pipelines"
## See Also
- `references/modules.md` - Detailed module guide (Predict, ChainOfThought, ReAct, ProgramOfThought)
- `references/optimizers.md` - Optimization algorithms (BootstrapFewShot, MIPRO, BootstrapFinetune)
- `references/examples.md` - Real-world examples (RAG, agents, classifiers)

View file

@ -0,0 +1,663 @@
# DSPy Real-World Examples
Practical examples of building production systems with DSPy.
## Table of Contents
- RAG Systems
- Agent Systems
- Classification
- Data Processing
- Multi-Stage Pipelines
## RAG Systems
### Basic RAG
```python
import dspy
class BasicRAG(dspy.Module):
def __init__(self, num_passages=3):
super().__init__()
self.retrieve = dspy.Retrieve(k=num_passages)
self.generate = dspy.ChainOfThought("context, question -> answer")
def forward(self, question):
passages = self.retrieve(question).passages
context = "\n\n".join(passages)
return self.generate(context=context, question=question)
# Configure retriever (example with Chroma)
from dspy.retrieve.chromadb_rm import ChromadbRM
retriever = ChromadbRM(
collection_name="my_docs",
persist_directory="./chroma_db",
k=3
)
dspy.settings.configure(rm=retriever)
# Use RAG
rag = BasicRAG()
result = rag(question="What is DSPy?")
print(result.answer)
```
### Optimized RAG
```python
from dspy.teleprompt import BootstrapFewShot
# Training data with question-answer pairs
trainset = [
dspy.Example(
question="What is retrieval augmented generation?",
answer="RAG combines retrieval of relevant documents with generation..."
).with_inputs("question"),
# ... more examples
]
# Define metric
def answer_correctness(example, pred, trace=None):
# Check if answer contains key information
return example.answer.lower() in pred.answer.lower()
# Optimize RAG
optimizer = BootstrapFewShot(metric=answer_correctness)
optimized_rag = optimizer.compile(rag, trainset=trainset)
# Optimized RAG performs better on similar questions
result = optimized_rag(question="Explain RAG systems")
```
### Multi-Hop RAG
```python
class MultiHopRAG(dspy.Module):
"""RAG that follows chains of reasoning across documents."""
def __init__(self):
super().__init__()
self.retrieve = dspy.Retrieve(k=3)
self.generate_query = dspy.ChainOfThought("question -> search_query")
self.generate_answer = dspy.ChainOfThought("context, question -> answer")
def forward(self, question):
# First retrieval
query1 = self.generate_query(question=question).search_query
passages1 = self.retrieve(query1).passages
# Generate follow-up query based on first results
context1 = "\n".join(passages1)
query2 = self.generate_query(
question=f"Based on: {context1}\nFollow-up: {question}"
).search_query
# Second retrieval
passages2 = self.retrieve(query2).passages
# Combine all context
all_context = "\n\n".join(passages1 + passages2)
# Generate final answer
return self.generate_answer(context=all_context, question=question)
# Use multi-hop RAG
multi_rag = MultiHopRAG()
result = multi_rag(question="Who wrote the book that inspired Blade Runner?")
# Hop 1: Find "Blade Runner was based on..."
# Hop 2: Find author of that book
```
### RAG with Reranking
```python
class RerankedRAG(dspy.Module):
"""RAG with learned reranking of retrieved passages."""
def __init__(self):
super().__init__()
self.retrieve = dspy.Retrieve(k=10) # Get more candidates
self.rerank = dspy.Predict("question, passage -> relevance_score: float")
self.answer = dspy.ChainOfThought("context, question -> answer")
def forward(self, question):
# Retrieve candidates
passages = self.retrieve(question).passages
# Rerank passages
scored_passages = []
for passage in passages:
score = float(self.rerank(
question=question,
passage=passage
).relevance_score)
scored_passages.append((score, passage))
# Take top 3 after reranking
top_passages = [p for _, p in sorted(scored_passages, reverse=True)[:3]]
context = "\n\n".join(top_passages)
# Generate answer from reranked context
return self.answer(context=context, question=question)
```
## Agent Systems
### ReAct Agent
```python
from dspy.predict import ReAct
# Define tools
def search_wikipedia(query: str) -> str:
"""Search Wikipedia for information."""
import wikipedia
try:
return wikipedia.summary(query, sentences=3)
except:
return "No results found"
def calculate(expression: str) -> str:
"""Evaluate mathematical expression safely."""
try:
# Use safe eval
result = eval(expression, {"__builtins__": {}}, {})
return str(result)
except:
return "Invalid expression"
def search_web(query: str) -> str:
"""Search the web."""
    # Your web search implementation goes here
    results = "..."  # placeholder result
    return results
# Create agent signature
class ResearchAgent(dspy.Signature):
"""Answer questions using available tools."""
question = dspy.InputField()
answer = dspy.OutputField()
# Create ReAct agent
agent = ReAct(ResearchAgent, tools=[search_wikipedia, calculate, search_web])
# Agent decides which tools to use
result = agent(question="What is the population of France divided by 10?")
# Agent:
# 1. Thinks: "Need population of France"
# 2. Acts: search_wikipedia("France population")
# 3. Thinks: "Got 67 million, need to divide"
# 4. Acts: calculate("67000000 / 10")
# 5. Returns: "6,700,000"
```
### Multi-Agent System
```python
class MultiAgentSystem(dspy.Module):
"""System with specialized agents for different tasks."""
def __init__(self):
super().__init__()
# Router agent
self.router = dspy.Predict("question -> agent_type: str")
# Specialized agents
self.research_agent = ReAct(
ResearchAgent,
tools=[search_wikipedia, search_web]
)
self.math_agent = dspy.ProgramOfThought("problem -> answer")
self.reasoning_agent = dspy.ChainOfThought("question -> answer")
def forward(self, question):
# Route to appropriate agent
agent_type = self.router(question=question).agent_type
if agent_type == "research":
return self.research_agent(question=question)
elif agent_type == "math":
return self.math_agent(problem=question)
else:
return self.reasoning_agent(question=question)
# Use multi-agent system
mas = MultiAgentSystem()
result = mas(question="What is 15% of the GDP of France?")
# Routes to research_agent for GDP, then to math_agent for calculation
```
## Classification
### Binary Classifier
```python
class SentimentClassifier(dspy.Module):
def __init__(self):
super().__init__()
self.classify = dspy.Predict("text -> sentiment: str")
def forward(self, text):
return self.classify(text=text)
# Training data
trainset = [
dspy.Example(text="I love this!", sentiment="positive").with_inputs("text"),
dspy.Example(text="Terrible experience", sentiment="negative").with_inputs("text"),
# ... more examples
]
# Optimize
def accuracy(example, pred, trace=None):
return example.sentiment == pred.sentiment
optimizer = BootstrapFewShot(metric=accuracy, max_bootstrapped_demos=5)
classifier = SentimentClassifier()
optimized_classifier = optimizer.compile(classifier, trainset=trainset)
# Use classifier
result = optimized_classifier(text="This product is amazing!")
print(result.sentiment) # "positive"
```
### Multi-Class Classifier
```python
class TopicClassifier(dspy.Module):
def __init__(self):
super().__init__()
self.classify = dspy.ChainOfThought(
"text -> category: str, confidence: float"
)
def forward(self, text):
result = self.classify(text=text)
return dspy.Prediction(
category=result.category,
confidence=float(result.confidence)
)
# Define categories in signature
class TopicSignature(dspy.Signature):
"""Classify text into one of: technology, sports, politics, entertainment."""
text = dspy.InputField()
category = dspy.OutputField(desc="one of: technology, sports, politics, entertainment")
confidence = dspy.OutputField(desc="0.0 to 1.0")
classifier = dspy.ChainOfThought(TopicSignature)
result = classifier(text="The Lakers won the championship")
print(result.category) # "sports"
print(result.confidence) # 0.95
```
### Hierarchical Classifier
```python
class HierarchicalClassifier(dspy.Module):
"""Two-stage classification: coarse then fine-grained."""
def __init__(self):
super().__init__()
self.coarse = dspy.Predict("text -> broad_category: str")
self.fine_tech = dspy.Predict("text -> tech_subcategory: str")
self.fine_sports = dspy.Predict("text -> sports_subcategory: str")
def forward(self, text):
# Stage 1: Broad category
broad = self.coarse(text=text).broad_category
# Stage 2: Fine-grained based on broad
if broad == "technology":
fine = self.fine_tech(text=text).tech_subcategory
elif broad == "sports":
fine = self.fine_sports(text=text).sports_subcategory
else:
fine = "other"
return dspy.Prediction(broad_category=broad, fine_category=fine)
```
## Data Processing
### Text Summarization
```python
class AdaptiveSummarizer(dspy.Module):
"""Summarizes text to target length."""
def __init__(self):
super().__init__()
self.summarize = dspy.ChainOfThought("text, target_length -> summary")
def forward(self, text, target_length="3 sentences"):
return self.summarize(text=text, target_length=target_length)
# Use summarizer
summarizer = AdaptiveSummarizer()
long_text = "..." # Long article
short_summary = summarizer(long_text, target_length="1 sentence")
medium_summary = summarizer(long_text, target_length="3 sentences")
detailed_summary = summarizer(long_text, target_length="1 paragraph")
```
### Information Extraction
```python
from pydantic import BaseModel, Field
class PersonInfo(BaseModel):
name: str = Field(description="Full name")
age: int = Field(description="Age in years")
occupation: str = Field(description="Job title")
location: str = Field(description="City and country")
class ExtractPerson(dspy.Signature):
"""Extract person information from text."""
text = dspy.InputField()
person: PersonInfo = dspy.OutputField()
extractor = dspy.TypedPredictor(ExtractPerson)
text = "Dr. Jane Smith, 42, is a neuroscientist at Stanford University in Palo Alto, California."
result = extractor(text=text)
print(result.person.name) # "Dr. Jane Smith"
print(result.person.age) # 42
print(result.person.occupation) # "neuroscientist"
print(result.person.location) # "Palo Alto, California"
```
### Batch Processing
```python
class BatchProcessor(dspy.Module):
"""Process large datasets efficiently."""
def __init__(self):
super().__init__()
self.process = dspy.Predict("text -> processed_text")
def forward(self, texts):
# Batch processing for efficiency
return self.process.batch([{"text": t} for t in texts])
# Process 1000 documents
processor = BatchProcessor()
results = processor(texts=large_dataset)
# Results are returned in order
for original, result in zip(large_dataset, results):
print(f"{original} -> {result.processed_text}")
```
## Multi-Stage Pipelines
### Document Processing Pipeline
```python
class DocumentPipeline(dspy.Module):
"""Multi-stage document processing."""
def __init__(self):
super().__init__()
self.extract = dspy.Predict("document -> key_points")
self.classify = dspy.Predict("key_points -> category")
self.summarize = dspy.ChainOfThought("key_points, category -> summary")
self.tag = dspy.Predict("summary -> tags")
def forward(self, document):
# Stage 1: Extract key points
key_points = self.extract(document=document).key_points
# Stage 2: Classify
category = self.classify(key_points=key_points).category
# Stage 3: Summarize
summary = self.summarize(
key_points=key_points,
category=category
).summary
# Stage 4: Generate tags
tags = self.tag(summary=summary).tags
return dspy.Prediction(
key_points=key_points,
category=category,
summary=summary,
tags=tags
)
```
### Quality Control Pipeline
```python
class QualityControlPipeline(dspy.Module):
"""Generate output and verify quality."""
def __init__(self):
super().__init__()
self.generate = dspy.ChainOfThought("prompt -> output")
self.verify = dspy.Predict("output -> is_valid: bool, issues: str")
self.improve = dspy.ChainOfThought("output, issues -> improved_output")
def forward(self, prompt, max_iterations=3):
output = self.generate(prompt=prompt).output
for _ in range(max_iterations):
# Verify output
verification = self.verify(output=output)
if verification.is_valid:
return dspy.Prediction(output=output, iterations=_ + 1)
# Improve based on issues
output = self.improve(
output=output,
issues=verification.issues
).improved_output
return dspy.Prediction(output=output, iterations=max_iterations)
```
## Production Tips
### 1. Caching for Performance
```python
from functools import lru_cache
class CachedRAG(dspy.Module):
def __init__(self):
super().__init__()
self.retrieve = dspy.Retrieve(k=3)
self.generate = dspy.ChainOfThought("context, question -> answer")
@lru_cache(maxsize=1000)
def forward(self, question):
passages = self.retrieve(question).passages
context = "\n".join(passages)
return self.generate(context=context, question=question).answer
```
### 2. Error Handling
```python
class RobustModule(dspy.Module):
def __init__(self):
super().__init__()
self.process = dspy.ChainOfThought("input -> output")
def forward(self, input):
try:
result = self.process(input=input)
return result
except Exception as e:
# Log error
print(f"Error processing {input}: {e}")
# Return fallback
return dspy.Prediction(output="Error: could not process input")
```
### 3. Monitoring
```python
class MonitoredModule(dspy.Module):
def __init__(self):
super().__init__()
self.process = dspy.ChainOfThought("input -> output")
self.call_count = 0
self.errors = 0
def forward(self, input):
self.call_count += 1
try:
result = self.process(input=input)
return result
except Exception as e:
self.errors += 1
raise
def get_stats(self):
return {
"calls": self.call_count,
"errors": self.errors,
"error_rate": self.errors / max(self.call_count, 1)
}
```
### 4. A/B Testing
```python
class ABTestModule(dspy.Module):
"""Run two variants and compare."""
def __init__(self, variant_a, variant_b):
super().__init__()
self.variant_a = variant_a
self.variant_b = variant_b
self.a_calls = 0
self.b_calls = 0
    def forward(self, variant="a", **kwargs):
        if variant == "a":
            self.a_calls += 1
            return self.variant_a(**kwargs)
        else:
            self.b_calls += 1
            return self.variant_b(**kwargs)
# Compare two optimizers
baseline = dspy.ChainOfThought("question -> answer")
optimized = BootstrapFewShot(...).compile(baseline, trainset=trainset)
ab_test = ABTestModule(variant_a=baseline, variant_b=optimized)
# Route 50% to each
import random
variant = "a" if random.random() < 0.5 else "b"
result = ab_test(question=question, variant=variant)
```
## Complete Example: Customer Support Bot
```python
import dspy
from dspy.teleprompt import BootstrapFewShot
class CustomerSupportBot(dspy.Module):
"""Complete customer support system."""
def __init__(self):
super().__init__()
# Classify intent
self.classify_intent = dspy.Predict("message -> intent: str")
# Specialized handlers
self.technical_handler = dspy.ChainOfThought("message, history -> response")
self.billing_handler = dspy.ChainOfThought("message, history -> response")
self.general_handler = dspy.Predict("message, history -> response")
# Retrieve relevant docs
self.retrieve = dspy.Retrieve(k=3)
# Conversation history
self.history = []
def forward(self, message):
# Classify intent
intent = self.classify_intent(message=message).intent
# Retrieve relevant documentation
docs = self.retrieve(message).passages
context = "\n".join(docs)
# Add context to history
history_str = "\n".join(self.history)
full_message = f"Context: {context}\n\nMessage: {message}"
# Route to appropriate handler
if intent == "technical":
response = self.technical_handler(
message=full_message,
history=history_str
).response
elif intent == "billing":
response = self.billing_handler(
message=full_message,
history=history_str
).response
else:
response = self.general_handler(
message=full_message,
history=history_str
).response
# Update history
self.history.append(f"User: {message}")
self.history.append(f"Bot: {response}")
return dspy.Prediction(response=response, intent=intent)
# Training data
trainset = [
dspy.Example(
message="My account isn't working",
intent="technical",
response="I'd be happy to help. What error are you seeing?"
).with_inputs("message"),
# ... more examples
]
# Define metric
def response_quality(example, pred, trace=None):
# Check if response is helpful
if len(pred.response) < 20:
return 0.0
if example.intent != pred.intent:
return 0.3
return 1.0
# Optimize
optimizer = BootstrapFewShot(metric=response_quality)
bot = CustomerSupportBot()
optimized_bot = optimizer.compile(bot, trainset=trainset)
# Use in production
optimized_bot.save("models/support_bot_v1.json")
# Later, load and use
loaded_bot = CustomerSupportBot()
loaded_bot.load("models/support_bot_v1.json")
response = loaded_bot(message="I can't log in")
```
## Resources
- **Documentation**: https://dspy.ai
- **Examples Repo**: https://github.com/stanfordnlp/dspy/tree/main/examples
- **Discord**: https://discord.gg/XCGy2WDCQB

View file

@ -0,0 +1,475 @@
# DSPy Modules
Complete guide to DSPy's built-in modules for language model programming.
## Module Basics
DSPy modules are composable building blocks inspired by PyTorch's NN modules:
- Have learnable parameters (prompts, few-shot examples)
- Can be composed using Python control flow
- Generalized to handle any signature
- Optimizable with DSPy optimizers
### Base Module Pattern
```python
import dspy
class CustomModule(dspy.Module):
def __init__(self):
super().__init__()
# Initialize sub-modules
self.predictor = dspy.Predict("input -> output")
def forward(self, input):
# Module logic
result = self.predictor(input=input)
return result
```
## Core Modules
### dspy.Predict
**Basic prediction module** - Makes LM calls without reasoning steps.
```python
# Inline signature
qa = dspy.Predict("question -> answer")
result = qa(question="What is 2+2?")
# Class signature
class QA(dspy.Signature):
"""Answer questions concisely."""
question = dspy.InputField()
answer = dspy.OutputField(desc="short, factual answer")
qa = dspy.Predict(QA)
result = qa(question="What is the capital of France?")
print(result.answer) # "Paris"
```
**When to use:**
- Simple, direct predictions
- No reasoning steps needed
- Fast responses required
### dspy.ChainOfThought
**Step-by-step reasoning** - Generates rationale before answer.
**Parameters:**
- `signature`: Task signature
- `rationale_field`: Custom reasoning field (optional)
- `rationale_field_type`: Type for rationale (default: `str`)
```python
# Basic usage
cot = dspy.ChainOfThought("question -> answer")
result = cot(question="If I have 5 apples and give away 2, how many remain?")
print(result.rationale) # "Let's think step by step..."
print(result.answer) # "3"
# Custom rationale field
cot = dspy.ChainOfThought(
signature="problem -> solution",
rationale_field=dspy.OutputField(
prefix="Reasoning: Let's break this down step by step to"
)
)
```
**When to use:**
- Complex reasoning tasks
- Math word problems
- Logical deduction
- Quality > speed
**Performance:**
- ~2x slower than Predict
- Significantly better accuracy on reasoning tasks
### dspy.ProgramOfThought
**Code-based reasoning** - Generates and executes Python code.
```python
pot = dspy.ProgramOfThought("question -> answer")
result = pot(question="What is 15% of 240?")
# Internally generates: answer = 240 * 0.15
# Executes code and returns result
print(result.answer) # 36.0
result = pot(question="If a train travels 60 mph for 2.5 hours, how far does it go?")
# Generates: distance = 60 * 2.5
print(result.answer) # 150.0
```
**When to use:**
- Arithmetic calculations
- Symbolic math
- Data transformations
- Deterministic computations
**Benefits:**
- More reliable than text-based math
- Handles complex calculations
- Transparent (shows generated code)
### dspy.ReAct
**Reasoning + Acting** - Agent that uses tools iteratively.
```python
from dspy.predict import ReAct
# Define tools
def search_wikipedia(query: str) -> str:
"""Search Wikipedia for information."""
    # Your search implementation goes here
    search_results = "..."  # placeholder result
    return search_results
def calculate(expression: str) -> float:
"""Evaluate a mathematical expression."""
return eval(expression)
# Create ReAct agent
class ResearchQA(dspy.Signature):
"""Answer questions using available tools."""
question = dspy.InputField()
answer = dspy.OutputField()
react = ReAct(ResearchQA, tools=[search_wikipedia, calculate])
# Agent decides which tools to use
result = react(question="How old was Einstein when he published special relativity?")
# Internally:
# 1. Thinks: "Need birth year and publication year"
# 2. Acts: search_wikipedia("Albert Einstein")
# 3. Acts: search_wikipedia("Special relativity 1905")
# 4. Acts: calculate("1905 - 1879")
# 5. Returns: "26 years old"
```
**When to use:**
- Multi-step research tasks
- Tool-using agents
- Complex information retrieval
- Tasks requiring multiple API calls
**Best practices:**
- Keep tool descriptions clear and specific
- Limit to 5-7 tools (too many = confusion)
- Provide tool usage examples in docstrings (see the sketch below)
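For illustration, a tool written to those guidelines might look like the following sketch (the `get_weather` helper and its return format are hypothetical):

```python
def get_weather(city: str) -> str:
    """Return the current weather for a city.

    Use when the question asks about weather conditions or temperature.
    Example: get_weather("Paris") -> "Paris: 18C, partly cloudy"
    """
    # Hypothetical implementation -- call your weather API here
    return f"{city}: 18C, partly cloudy"

agent = ReAct(ResearchQA, tools=[get_weather, calculate])
```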
### dspy.MultiChainComparison
**Generate multiple outputs and compare** - Self-consistency pattern.
```python
mcc = dspy.MultiChainComparison("question -> answer", M=5)
result = mcc(question="What is the capital of France?")
# Generates 5 candidate answers
# Compares and selects most consistent
print(result.answer) # "Paris"
print(result.candidates) # All 5 generated answers
```
**Parameters:**
- `M`: Number of candidates to generate (default: 5)
- `temperature`: Sampling temperature for diversity
**When to use:**
- High-stakes decisions
- Ambiguous questions
- When single answer may be unreliable
**Tradeoff:**
- M times slower (M parallel calls)
- Higher accuracy on ambiguous tasks
### dspy.majority
**Majority voting over multiple predictions.**
```python
from dspy.primitives import majority
# Generate multiple predictions
predictor = dspy.Predict("question -> answer")
predictions = [predictor(question="What is 2+2?") for _ in range(5)]
# Take majority vote
answer = majority([p.answer for p in predictions])
print(answer) # "4"
```
**When to use:**
- Combining multiple model outputs
- Reducing variance in predictions
- Ensemble approaches
## Advanced Modules
### dspy.TypedPredictor
**Structured output with Pydantic models.**
```python
from pydantic import BaseModel, Field
class PersonInfo(BaseModel):
name: str = Field(description="Full name")
age: int = Field(description="Age in years")
occupation: str = Field(description="Current job")
class ExtractPerson(dspy.Signature):
"""Extract person information from text."""
text = dspy.InputField()
person: PersonInfo = dspy.OutputField()
extractor = dspy.TypedPredictor(ExtractPerson)
result = extractor(text="John Doe is a 35-year-old software engineer.")
print(result.person.name) # "John Doe"
print(result.person.age) # 35
print(result.person.occupation) # "software engineer"
```
**Benefits:**
- Type safety
- Automatic validation
- JSON schema generation
- IDE autocomplete
### dspy.Retry
**Automatic retry with validation.**
```python
from dspy.primitives import Retry
def validate_number(example, pred, trace=None):
"""Validate output is a number."""
try:
float(pred.answer)
return True
except ValueError:
return False
# Retry up to 3 times if validation fails
qa = Retry(
dspy.ChainOfThought("question -> answer"),
validate=validate_number,
max_retries=3
)
result = qa(question="What is 15% of 80?")
# If first attempt returns non-numeric, retries automatically
```
### dspy.Assert
**Assertion-driven optimization.**
```python
import dspy
from dspy.primitives.assertions import assert_transform_module, backtrack_handler
class ValidatedQA(dspy.Module):
def __init__(self):
super().__init__()
self.qa = dspy.ChainOfThought("question -> answer: float")
def forward(self, question):
answer = self.qa(question=question).answer
# Assert answer is numeric
dspy.Assert(
isinstance(float(answer), float),
"Answer must be a number",
backtrack=backtrack_handler
)
return dspy.Prediction(answer=answer)
```
**Benefits:**
- Catches errors during optimization
- Guides LM toward valid outputs
- Better than post-hoc filtering
## Module Composition
### Sequential Pipeline
```python
class Pipeline(dspy.Module):
def __init__(self):
super().__init__()
self.stage1 = dspy.Predict("input -> intermediate")
self.stage2 = dspy.ChainOfThought("intermediate -> output")
def forward(self, input):
intermediate = self.stage1(input=input).intermediate
output = self.stage2(intermediate=intermediate).output
return dspy.Prediction(output=output)
```
### Conditional Logic
```python
class ConditionalModule(dspy.Module):
def __init__(self):
super().__init__()
self.router = dspy.Predict("question -> category: str")
self.simple_qa = dspy.Predict("question -> answer")
self.complex_qa = dspy.ChainOfThought("question -> answer")
def forward(self, question):
category = self.router(question=question).category
if category == "simple":
return self.simple_qa(question=question)
else:
return self.complex_qa(question=question)
```
### Parallel Execution
```python
class ParallelModule(dspy.Module):
def __init__(self):
super().__init__()
self.approach1 = dspy.ChainOfThought("question -> answer")
self.approach2 = dspy.ProgramOfThought("question -> answer")
def forward(self, question):
# Run both approaches
answer1 = self.approach1(question=question).answer
answer2 = self.approach2(question=question).answer
# Compare or combine results
if answer1 == answer2:
return dspy.Prediction(answer=answer1, confidence="high")
else:
return dspy.Prediction(answer=answer1, confidence="low")
```
## Batch Processing
All modules support batch processing for efficiency:
```python
cot = dspy.ChainOfThought("question -> answer")
questions = [
"What is 2+2?",
"What is 3+3?",
"What is 4+4?"
]
# Process all at once
results = cot.batch([{"question": q} for q in questions])
for result in results:
print(result.answer)
```
## Saving and Loading
```python
# Save module
qa = dspy.ChainOfThought("question -> answer")
qa.save("models/qa_v1.json")
# Load module
loaded_qa = dspy.ChainOfThought("question -> answer")
loaded_qa.load("models/qa_v1.json")
```
**What gets saved:**
- Few-shot examples
- Prompt instructions
- Module configuration
**What doesn't get saved:**
- Model weights (DSPy doesn't fine-tune by default)
- LM provider configuration
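Because the LM provider configuration is not restored, set it up again before loading. A minimal sketch, reusing the provider configured earlier in this guide:

```python
import dspy

# Reconfigure the LM first (load() does not restore it)
dspy.settings.configure(lm=dspy.OpenAI(model="gpt-4"))

# Then load the saved demos and instructions into a matching module
qa = dspy.ChainOfThought("question -> answer")
qa.load("models/qa_v1.json")
```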
## Module Selection Guide
| Task | Module | Reason |
|------|--------|--------|
| Simple classification | Predict | Fast, direct |
| Math word problems | ProgramOfThought | Reliable calculations |
| Logical reasoning | ChainOfThought | Better with steps |
| Multi-step research | ReAct | Tool usage |
| High-stakes decisions | MultiChainComparison | Self-consistency |
| Structured extraction | TypedPredictor | Type safety |
| Ambiguous questions | MultiChainComparison | Multiple perspectives |
## Performance Tips
1. **Start with Predict**, add reasoning only if needed
2. **Use batch processing** for multiple inputs
3. **Cache predictions** for repeated queries (see the sketch after this list)
4. **Profile token usage** with `track_usage=True`
5. **Optimize after prototyping** with teleprompters
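A minimal sketch of tip 3: memoizing repeated queries with a plain dict (the `CachedQA` wrapper is illustrative, not a DSPy built-in):

```python
import dspy

class CachedQA(dspy.Module):
    """Wraps a predictor and caches answers for repeated questions."""

    def __init__(self):
        super().__init__()
        self.qa = dspy.Predict("question -> answer")
        self._cache = {}

    def forward(self, question):
        if question not in self._cache:
            self._cache[question] = self.qa(question=question)
        return self._cache[question]
```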
## Common Patterns
### Pattern: Retrieval + Generation
```python
class RAG(dspy.Module):
def __init__(self, k=3):
super().__init__()
self.retrieve = dspy.Retrieve(k=k)
self.generate = dspy.ChainOfThought("context, question -> answer")
def forward(self, question):
context = self.retrieve(question).passages
return self.generate(context=context, question=question)
```
### Pattern: Verification Loop
```python
class VerifiedQA(dspy.Module):
def __init__(self):
super().__init__()
self.answer = dspy.ChainOfThought("question -> answer")
self.verify = dspy.Predict("question, answer -> is_correct: bool")
def forward(self, question, max_attempts=3):
for _ in range(max_attempts):
answer = self.answer(question=question).answer
is_correct = self.verify(question=question, answer=answer).is_correct
if is_correct:
return dspy.Prediction(answer=answer)
return dspy.Prediction(answer="Unable to verify answer")
```
### Pattern: Multi-Turn Dialog
```python
class DialogAgent(dspy.Module):
def __init__(self):
super().__init__()
self.respond = dspy.Predict("history, user_message -> assistant_message")
self.history = []
def forward(self, user_message):
history_str = "\n".join(self.history)
response = self.respond(history=history_str, user_message=user_message)
self.history.append(f"User: {user_message}")
self.history.append(f"Assistant: {response.assistant_message}")
return response
```

View file

@ -0,0 +1,566 @@
# DSPy Optimizers (Teleprompters)
Complete guide to DSPy's optimization algorithms for improving prompts and model weights.
## What are Optimizers?
DSPy optimizers (called "teleprompters") automatically improve your modules by:
- **Synthesizing few-shot examples** from training data
- **Proposing better instructions** through search
- **Fine-tuning model weights** (optional)
**Key idea**: Instead of manually tuning prompts, define a metric and let DSPy optimize.
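The whole loop fits in a few lines. A minimal sketch, assuming `trainset` is already a list of labeled `dspy.Example` objects:

```python
import dspy
from dspy.teleprompt import BootstrapFewShot

def metric(example, pred, trace=None):
    return example.answer.lower() == pred.answer.lower()

qa = dspy.ChainOfThought("question -> answer")
optimized_qa = BootstrapFewShot(metric=metric).compile(qa, trainset=trainset)
```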
## Optimizer Selection Guide
| Optimizer | Best For | Speed | Quality | Data Needed |
|-----------|----------|-------|---------|-------------|
| BootstrapFewShot | General purpose | Fast | Good | 10-50 examples |
| MIPRO | Instruction tuning | Medium | Excellent | 50-200 examples |
| BootstrapFinetune | Fine-tuning | Slow | Excellent | 100+ examples |
| COPRO | Prompt optimization | Medium | Good | 20-100 examples |
| KNNFewShot | Quick baseline | Very fast | Fair | 10+ examples |
## Core Optimizers
### BootstrapFewShot
**Most popular optimizer** - Generates few-shot demonstrations from training data.
**How it works:**
1. Takes your training examples
2. Uses your module to generate predictions
3. Selects high-quality predictions (based on metric)
4. Uses these as few-shot examples in future prompts
**Parameters:**
- `metric`: Function that scores predictions (required)
- `max_bootstrapped_demos`: Max demonstrations to generate (default: 4)
- `max_labeled_demos`: Max labeled examples to use (default: 16)
- `max_rounds`: Optimization iterations (default: 1)
- `metric_threshold`: Minimum score to accept (optional)
```python
import dspy
from dspy.teleprompt import BootstrapFewShot
# Define metric
def validate_answer(example, pred, trace=None):
"""Return True if prediction matches gold answer."""
return example.answer.lower() == pred.answer.lower()
# Training data
trainset = [
dspy.Example(question="What is 2+2?", answer="4").with_inputs("question"),
dspy.Example(question="What is 3+5?", answer="8").with_inputs("question"),
dspy.Example(question="What is 10-3?", answer="7").with_inputs("question"),
]
# Create module
qa = dspy.ChainOfThought("question -> answer")
# Optimize
optimizer = BootstrapFewShot(
metric=validate_answer,
max_bootstrapped_demos=3,
max_rounds=2
)
optimized_qa = optimizer.compile(qa, trainset=trainset)
# Now optimized_qa has learned few-shot examples!
result = optimized_qa(question="What is 5+7?")
```
**Best practices:**
- Start with 10-50 training examples
- Use diverse examples covering edge cases
- Set `max_bootstrapped_demos=3-5` for most tasks
- Increase `max_rounds=2-3` for better quality
**When to use:**
- First optimizer to try
- You have 10+ labeled examples
- Want quick improvements
- General-purpose tasks
### MIPRO (Most Important Prompt Optimization)
**State-of-the-art optimizer** - Iteratively searches for better instructions.
**How it works:**
1. Generates candidate instructions
2. Tests each on validation set
3. Selects best-performing instructions
4. Iterates to refine further
**Parameters:**
- `metric`: Evaluation metric (required)
- `num_candidates`: Instructions to try per iteration (default: 10)
- `init_temperature`: Sampling temperature (default: 1.0)
- `verbose`: Show progress (default: False)
```python
from dspy.teleprompt import MIPRO
# Define metric with more nuance
def answer_quality(example, pred, trace=None):
"""Score answer quality 0-1."""
if example.answer.lower() in pred.answer.lower():
return 1.0
# Partial credit for similar answers
return 0.5 if len(set(example.answer.split()) & set(pred.answer.split())) > 0 else 0.0
# Larger training set (MIPRO benefits from more data)
trainset = [...] # 50-200 examples
valset = [...] # 20-50 examples
# Create module
qa = dspy.ChainOfThought("question -> answer")
# Optimize with MIPRO
optimizer = MIPRO(
metric=answer_quality,
num_candidates=10,
init_temperature=1.0,
verbose=True
)
optimized_qa = optimizer.compile(
student=qa,
trainset=trainset,
valset=valset, # MIPRO uses separate validation set
num_trials=100 # More trials = better quality
)
```
**Best practices:**
- Use 50-200 training examples
- Separate validation set (20-50 examples)
- Run 100-200 trials for best results
- Takes 10-30 minutes typically
**When to use:**
- You have 50+ labeled examples
- Want state-of-the-art performance
- Willing to wait for optimization
- Complex reasoning tasks
### BootstrapFinetune
**Fine-tune model weights** - Creates training dataset for fine-tuning.
**How it works:**
1. Generates synthetic training data
2. Exports data in fine-tuning format
3. You fine-tune model separately
4. Load fine-tuned model back
**Parameters:**
- `metric`: Evaluation metric (required)
- `max_bootstrapped_demos`: Demonstrations to generate (default: 4)
- `max_rounds`: Data generation rounds (default: 1)
```python
from dspy.teleprompt import BootstrapFinetune
# Training data
trainset = [...] # 100+ examples recommended
# Define metric
def validate(example, pred, trace=None):
return example.answer == pred.answer
# Create module
qa = dspy.ChainOfThought("question -> answer")
# Generate fine-tuning data
optimizer = BootstrapFinetune(metric=validate)
optimized_qa = optimizer.compile(qa, trainset=trainset)
# Exports training data to file
# You then fine-tune using your LM provider's API
# After fine-tuning, load your model:
finetuned_lm = dspy.OpenAI(model="ft:gpt-3.5-turbo:your-model-id")
dspy.settings.configure(lm=finetuned_lm)
```
**Best practices:**
- Use 100+ training examples
- Validate on held-out test set
- Monitor for overfitting
- Compare with prompt-based methods first
**When to use:**
- You have 100+ examples
- Latency is critical (fine-tuned models faster)
- Task is narrow and well-defined
- Prompt optimization isn't enough
### COPRO (Coordinate Prompt Optimization)
**Optimize prompts via gradient-free search.**
**How it works:**
1. Generates prompt variants
2. Evaluates each variant
3. Selects best prompts
4. Iterates to refine
```python
from dspy.teleprompt import COPRO
# Training data
trainset = [...]
# Define metric
def metric(example, pred, trace=None):
return example.answer == pred.answer
# Create module
qa = dspy.ChainOfThought("question -> answer")
# Optimize with COPRO
optimizer = COPRO(
metric=metric,
breadth=10, # Candidates per iteration
depth=3 # Optimization rounds
)
optimized_qa = optimizer.compile(qa, trainset=trainset)
```
**When to use:**
- Want prompt optimization
- Have 20-100 examples
- MIPRO too slow
### KNNFewShot
**Simple k-nearest neighbors** - Selects similar examples for each query.
**How it works:**
1. Embeds all training examples
2. For each query, finds k most similar examples
3. Uses these as few-shot demonstrations
```python
from dspy.teleprompt import KNNFewShot
trainset = [...]
# No metric needed - just selects similar examples
optimizer = KNNFewShot(k=3)
optimized_qa = optimizer.compile(qa, trainset=trainset)
# For each query, uses 3 most similar examples from trainset
```
**When to use:**
- Quick baseline
- Have diverse training examples
- Similarity is good proxy for helpfulness
## Writing Metrics
Metrics are functions that score predictions. They're critical for optimization.
### Binary Metrics
```python
def exact_match(example, pred, trace=None):
"""Return True if prediction exactly matches gold."""
return example.answer == pred.answer
def contains_answer(example, pred, trace=None):
"""Return True if prediction contains gold answer."""
return example.answer.lower() in pred.answer.lower()
```
### Continuous Metrics
```python
def f1_score(example, pred, trace=None):
"""F1 score between prediction and gold."""
pred_tokens = set(pred.answer.lower().split())
gold_tokens = set(example.answer.lower().split())
if not pred_tokens:
return 0.0
precision = len(pred_tokens & gold_tokens) / len(pred_tokens)
recall = len(pred_tokens & gold_tokens) / len(gold_tokens)
if precision + recall == 0:
return 0.0
return 2 * (precision * recall) / (precision + recall)
def semantic_similarity(example, pred, trace=None):
"""Embedding similarity between prediction and gold."""
    from sentence_transformers import SentenceTransformer, util
    model = SentenceTransformer('all-MiniLM-L6-v2')
    emb1 = model.encode(example.answer)
    emb2 = model.encode(pred.answer)
    return float(util.cos_sim(emb1, emb2))
```
### Multi-Factor Metrics
```python
def comprehensive_metric(example, pred, trace=None):
"""Combine multiple factors."""
score = 0.0
# Correctness (50%)
if example.answer.lower() in pred.answer.lower():
score += 0.5
# Conciseness (25%)
if len(pred.answer.split()) <= 20:
score += 0.25
# Citation (25%)
if "source:" in pred.answer.lower():
score += 0.25
return score
```
### Using Trace for Debugging
```python
def metric_with_trace(example, pred, trace=None):
"""Metric that uses trace for debugging."""
is_correct = example.answer == pred.answer
if trace is not None and not is_correct:
# Log failures for analysis
print(f"Failed on: {example.question}")
print(f"Expected: {example.answer}")
print(f"Got: {pred.answer}")
return is_correct
```
## Evaluation Best Practices
### Train/Val/Test Split
```python
# Split data
trainset = data[:100] # 70%
valset = data[100:120] # 15%
testset = data[120:] # 15%
# Optimize on train
optimized = optimizer.compile(module, trainset=trainset)
# Validate during optimization (for MIPRO)
optimized = optimizer.compile(module, trainset=trainset, valset=valset)
# Evaluate on test
from dspy.evaluate import Evaluate
evaluator = Evaluate(devset=testset, metric=metric)
score = evaluator(optimized)
```
### Cross-Validation
```python
from sklearn.model_selection import KFold
kfold = KFold(n_splits=5)
scores = []
for train_idx, val_idx in kfold.split(data):
trainset = [data[i] for i in train_idx]
valset = [data[i] for i in val_idx]
optimized = optimizer.compile(module, trainset=trainset)
score = evaluator(optimized, devset=valset)
scores.append(score)
print(f"Average score: {sum(scores) / len(scores):.2f}")
```
### Comparing Optimizers
```python
results = {}
for opt_name, optimizer in [
("baseline", None),
("fewshot", BootstrapFewShot(metric=metric)),
("mipro", MIPRO(metric=metric)),
]:
if optimizer is None:
module_opt = module
else:
module_opt = optimizer.compile(module, trainset=trainset)
score = evaluator(module_opt, devset=testset)
results[opt_name] = score
print(results)
# {'baseline': 0.65, 'fewshot': 0.78, 'mipro': 0.85}
```
## Advanced Patterns
### Custom Optimizer
```python
from dspy.teleprompt import Teleprompter
class CustomOptimizer(Teleprompter):
    def __init__(self, metric, k=3):
        self.metric = metric
        self.k = k

    def compile(self, student, trainset, **kwargs):
        # Minimal illustration: attach the first k labeled examples
        # as demos to every predictor, then return the optimized copy
        compiled = student.deepcopy()
        for predictor in compiled.predictors():
            predictor.demos = trainset[:self.k]
        return compiled
```
### Multi-Stage Optimization
```python
# Stage 1: Bootstrap few-shot
stage1 = BootstrapFewShot(metric=metric, max_bootstrapped_demos=3)
optimized1 = stage1.compile(module, trainset=trainset)
# Stage 2: Instruction tuning
stage2 = MIPRO(metric=metric, num_candidates=10)
optimized2 = stage2.compile(optimized1, trainset=trainset, valset=valset)
# Final optimized module
final_module = optimized2
```
### Ensemble Optimization
```python
class EnsembleModule(dspy.Module):
def __init__(self, modules):
super().__init__()
self.modules = modules
def forward(self, question):
predictions = [m(question=question).answer for m in self.modules]
# Vote or average
return dspy.Prediction(answer=max(set(predictions), key=predictions.count))
# Optimize multiple modules
opt1 = BootstrapFewShot(metric=metric).compile(module, trainset=trainset)
opt2 = MIPRO(metric=metric).compile(module, trainset=trainset)
opt3 = COPRO(metric=metric).compile(module, trainset=trainset)
# Ensemble
ensemble = EnsembleModule([opt1, opt2, opt3])
```
## Optimization Workflow
### 1. Start with Baseline
```python
# No optimization
baseline = dspy.ChainOfThought("question -> answer")
baseline_score = evaluator(baseline, devset=testset)
print(f"Baseline: {baseline_score}")
```
### 2. Try BootstrapFewShot
```python
# Quick optimization
fewshot = BootstrapFewShot(metric=metric, max_bootstrapped_demos=3)
optimized = fewshot.compile(baseline, trainset=trainset)
fewshot_score = evaluator(optimized, devset=testset)
print(f"Few-shot: {fewshot_score} (+{fewshot_score - baseline_score:.2f})")
```
### 3. If More Data Available, Try MIPRO
```python
# State-of-the-art optimization
mipro = MIPRO(metric=metric, num_candidates=10)
optimized_mipro = mipro.compile(baseline, trainset=trainset, valset=valset)
mipro_score = evaluator(optimized_mipro, devset=testset)
print(f"MIPRO: {mipro_score} (+{mipro_score - baseline_score:.2f})")
```
### 4. Save Best Model
```python
if mipro_score > fewshot_score:
optimized_mipro.save("models/best_model.json")
else:
optimized.save("models/best_model.json")
```
## Common Pitfalls
### 1. Overfitting to Training Data
```python
# ❌ Bad: Too many demos
optimizer = BootstrapFewShot(metric=metric, max_bootstrapped_demos=20)  # Overfits!
# ✅ Good: Moderate demos
optimizer = BootstrapFewShot(metric=metric, max_bootstrapped_demos=4)   # typically 3-5
```
### 2. Metric Doesn't Match Task
```python
# ❌ Bad: Binary metric for nuanced task
def bad_metric(example, pred, trace=None):
return example.answer == pred.answer # Too strict!
# ✅ Good: Graded metric
def good_metric(example, pred, trace=None):
return f1_score(example.answer, pred.answer) # Allows partial credit
```
### 3. Insufficient Training Data
```python
# ❌ Bad: Too little data
trainset = data[:5] # Not enough!
# ✅ Good: Sufficient data
trainset = data[:50] # Better
```
### 4. No Validation Set
```python
# ❌ Bad: Optimizing on test set
optimizer.compile(module, trainset=testset) # Cheating!
# ✅ Good: Proper splits
optimizer.compile(module, trainset=trainset, valset=valset)
evaluator(optimized, devset=testset)
```
## Performance Tips
1. **Start simple**: BootstrapFewShot first
2. **Use representative data**: Cover edge cases
3. **Monitor overfitting**: Validate on held-out set
4. **Iterate metrics**: Refine based on failures
5. **Save checkpoints**: Don't lose progress
6. **Compare to baseline**: Measure improvement
7. **Test multiple optimizers**: Find best fit
## Resources
- **Paper**: "DSPy: Compiling Declarative Language Model Calls into Self-Improving Pipelines"
- **GitHub**: https://github.com/stanfordnlp/dspy
- **Discord**: https://discord.gg/XCGy2WDCQB

View file

@ -0,0 +1,3 @@
---
description: Fine-tuning, RLHF/DPO/GRPO training, distributed training frameworks, and optimization tools for training LLMs and other models.
---

View file

@ -0,0 +1,335 @@
---
name: huggingface-accelerate
description: Simplest distributed training API. 4 lines to add distributed support to any PyTorch script. Unified API for DeepSpeed/FSDP/Megatron/DDP. Automatic device placement, mixed precision (FP16/BF16/FP8). Interactive config, single launch command. HuggingFace ecosystem standard.
version: 1.0.0
author: Orchestra Research
license: MIT
dependencies: [accelerate, torch, transformers]
metadata:
hermes:
tags: [Distributed Training, HuggingFace, Accelerate, DeepSpeed, FSDP, Mixed Precision, PyTorch, DDP, Unified API, Simple]
---
# HuggingFace Accelerate - Unified Distributed Training
## Quick start
Accelerate simplifies distributed training to 4 lines of code.
**Installation**:
```bash
pip install accelerate
```
**Convert PyTorch script** (4 lines):
```python
import torch
+ from accelerate import Accelerator
+ accelerator = Accelerator()
model = torch.nn.Transformer()
optimizer = torch.optim.Adam(model.parameters())
dataloader = torch.utils.data.DataLoader(dataset)
+ model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
for batch in dataloader:
optimizer.zero_grad()
loss = model(batch)
- loss.backward()
+ accelerator.backward(loss)
optimizer.step()
```
**Run** (single command):
```bash
accelerate launch train.py
```
## Common workflows
### Workflow 1: From single GPU to multi-GPU
**Original script**:
```python
# train.py
import torch
model = torch.nn.Linear(10, 2).to('cuda')
optimizer = torch.optim.Adam(model.parameters())
dataloader = torch.utils.data.DataLoader(dataset, batch_size=32)
for epoch in range(10):
for batch in dataloader:
batch = batch.to('cuda')
optimizer.zero_grad()
loss = model(batch).mean()
loss.backward()
optimizer.step()
```
**With Accelerate** (4 lines added):
```python
# train.py
import torch
from accelerate import Accelerator # +1
accelerator = Accelerator() # +2
model = torch.nn.Linear(10, 2)
optimizer = torch.optim.Adam(model.parameters())
dataloader = torch.utils.data.DataLoader(dataset, batch_size=32)
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader) # +3
for epoch in range(10):
for batch in dataloader:
# No .to('cuda') needed - automatic!
optimizer.zero_grad()
loss = model(batch).mean()
accelerator.backward(loss) # +4
optimizer.step()
```
**Configure** (interactive):
```bash
accelerate config
```
**Questions**:
- Which machine? (single/multi GPU/TPU/CPU)
- How many machines? (1)
- Mixed precision? (no/fp16/bf16/fp8)
- DeepSpeed? (no/yes)
**Launch** (works on any setup):
```bash
# Single GPU
accelerate launch train.py
# Multi-GPU (8 GPUs)
accelerate launch --multi_gpu --num_processes 8 train.py
# Multi-node
accelerate launch --multi_gpu --num_processes 16 \
--num_machines 2 --machine_rank 0 \
--main_process_ip $MASTER_ADDR \
train.py
```
### Workflow 2: Mixed precision training
**Enable FP16/BF16**:
```python
from accelerate import Accelerator
# FP16 (with gradient scaling)
accelerator = Accelerator(mixed_precision='fp16')
# BF16 (no scaling, more stable)
accelerator = Accelerator(mixed_precision='bf16')
# FP8 (H100+)
accelerator = Accelerator(mixed_precision='fp8')
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
# Everything else is automatic!
for batch in dataloader:
with accelerator.autocast(): # Optional, done automatically
loss = model(batch)
accelerator.backward(loss)
```
### Workflow 3: DeepSpeed ZeRO integration
**Enable DeepSpeed ZeRO-2**:
```python
from accelerate import Accelerator
from accelerate.utils import DeepSpeedPlugin

accelerator = Accelerator(
    mixed_precision='bf16',
    deepspeed_plugin=DeepSpeedPlugin(
        zero_stage=2,                     # ZeRO-2
        offload_optimizer_device="none",  # keep optimizer states on GPU
        gradient_accumulation_steps=4
    )
)
# Same code as before!
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
```
**Or via config**:
```bash
accelerate config
# Select: DeepSpeed → ZeRO-2
```
**deepspeed_config.json**:
```json
{
"fp16": {"enabled": false},
"bf16": {"enabled": true},
"zero_optimization": {
"stage": 2,
"offload_optimizer": {"device": "cpu"},
"allgather_bucket_size": 5e8,
"reduce_bucket_size": 5e8
}
}
```
**Launch**:
```bash
accelerate launch --use_deepspeed --deepspeed_config_file deepspeed_config.json train.py
```
### Workflow 4: FSDP (Fully Sharded Data Parallel)
**Enable FSDP**:
```python
from accelerate import Accelerator, FullyShardedDataParallelPlugin
fsdp_plugin = FullyShardedDataParallelPlugin(
sharding_strategy="FULL_SHARD", # ZeRO-3 equivalent
auto_wrap_policy="TRANSFORMER_AUTO_WRAP",
cpu_offload=False
)
accelerator = Accelerator(
mixed_precision='bf16',
fsdp_plugin=fsdp_plugin
)
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
```
**Or via config**:
```bash
accelerate config
# Select: FSDP → Full Shard → No CPU Offload
```
### Workflow 5: Gradient accumulation
**Accumulate gradients**:
```python
from accelerate import Accelerator
accelerator = Accelerator(gradient_accumulation_steps=4)
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
for batch in dataloader:
with accelerator.accumulate(model): # Handles accumulation
optimizer.zero_grad()
loss = model(batch)
accelerator.backward(loss)
optimizer.step()
```
**Effective batch size**: `batch_size * num_gpus * gradient_accumulation_steps`
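For example, a per-GPU batch size of 32 on 8 GPUs with `gradient_accumulation_steps=4` gives an effective batch size of 32 × 8 × 4 = 1024.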
## When to use vs alternatives
**Use Accelerate when**:
- Want simplest distributed training
- Need single script for any hardware
- Use HuggingFace ecosystem
- Want flexibility (DDP/DeepSpeed/FSDP/Megatron)
- Need quick prototyping
**Key advantages**:
- **4 lines**: Minimal code changes
- **Unified API**: Same code for DDP, DeepSpeed, FSDP, Megatron
- **Automatic**: Device placement, mixed precision, sharding
- **Interactive config**: No manual launcher setup
- **Single launch**: Works everywhere
**Use alternatives instead**:
- **PyTorch Lightning**: Need callbacks, high-level abstractions
- **Ray Train**: Multi-node orchestration, hyperparameter tuning
- **DeepSpeed**: Direct API control, advanced features
- **Raw DDP**: Maximum control, minimal abstraction
## Common issues
**Issue: Wrong device placement**
Don't manually move to device:
```python
# WRONG
batch = batch.to('cuda')
# CORRECT
# Accelerate handles it automatically after prepare()
```
**Issue: Gradient accumulation not working**
Use context manager:
```python
# CORRECT
with accelerator.accumulate(model):
optimizer.zero_grad()
accelerator.backward(loss)
optimizer.step()
```
**Issue: Checkpointing in distributed**
Use accelerator methods:
```python
# Save only on main process
if accelerator.is_main_process:
accelerator.save_state('checkpoint/')
# Load on all processes
accelerator.load_state('checkpoint/')
```
**Issue: Different results with FSDP**
Ensure same random seed:
```python
from accelerate.utils import set_seed
set_seed(42)
```
## Advanced topics
**Megatron integration**: See [references/megatron-integration.md](references/megatron-integration.md) for tensor parallelism, pipeline parallelism, and sequence parallelism setup.
**Custom plugins**: See [references/custom-plugins.md](references/custom-plugins.md) for creating custom distributed plugins and advanced configuration.
**Performance tuning**: See [references/performance.md](references/performance.md) for profiling, memory optimization, and best practices.
## Hardware requirements
- **CPU**: Works (slow)
- **Single GPU**: Works
- **Multi-GPU**: DDP (default), DeepSpeed, or FSDP
- **Multi-node**: DDP, DeepSpeed, FSDP, Megatron
- **TPU**: Supported
- **Apple MPS**: Supported
**Launcher requirements**:
- **DDP**: `torch.distributed.run` (built-in)
- **DeepSpeed**: `deepspeed` (pip install deepspeed)
- **FSDP**: PyTorch 1.12+ (built-in)
- **Megatron**: Custom setup
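As a rough illustration of selecting a backend at launch time (flag names assume a recent Accelerate CLI):

```bash
accelerate launch train.py                                  # DDP (default)
accelerate launch --use_deepspeed --zero_stage 2 train.py   # DeepSpeed ZeRO-2
accelerate launch --use_fsdp train.py                       # FSDP
```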
## Resources
- Docs: https://huggingface.co/docs/accelerate
- GitHub: https://github.com/huggingface/accelerate
- Version: 1.11.0+
- Tutorial: "Accelerate your scripts"
- Examples: https://github.com/huggingface/accelerate/tree/main/examples
- Used by: HuggingFace Transformers, TRL, PEFT, all HF libraries

View file

@ -0,0 +1,453 @@
# Custom Plugins for Accelerate
## Overview
Accelerate allows creating **custom plugins** to extend distributed training strategies beyond built-in options (DDP, FSDP, DeepSpeed).
## Plugin Architecture
### Base Plugin Structure
```python
from accelerate.utils import DistributedDataParallelKwargs
from dataclasses import dataclass
@dataclass
class CustomPlugin:
"""Custom training plugin."""
# Plugin configuration
param1: int = 1
param2: str = "default"
def __post_init__(self):
# Validation logic
if self.param1 < 1:
raise ValueError("param1 must be >= 1")
```
### Using Custom Plugin
```python
from accelerate import Accelerator
# Create plugin
custom_plugin = CustomPlugin(param1=4, param2="value")
# Pass to Accelerator
accelerator = Accelerator(
    # Illustrative only: `custom_plugin` is not a real Accelerator argument;
    # real extension points are `kwargs_handlers` and the built-in *Plugin arguments shown below
    custom_plugin=custom_plugin
)
```
## Built-In Plugin Examples
### 1. GradScalerKwargs (FP16 Configuration)
```python
from accelerate.utils import GradScalerKwargs
# Configure gradient scaler for FP16
scaler_kwargs = GradScalerKwargs(
init_scale=2.**16, # Initial loss scale
growth_factor=2.0, # Scale growth rate
backoff_factor=0.5, # Scale backoff rate
growth_interval=2000, # Steps between scale increases
enabled=True # Enable scaler
)
accelerator = Accelerator(
mixed_precision='fp16',
kwargs_handlers=[scaler_kwargs] # Pass as kwargs handler
)
```
**Use case**: Fine-tune FP16 gradient scaling behavior
### 2. DistributedDataParallelKwargs
```python
from accelerate.utils import DistributedDataParallelKwargs
# Configure DDP behavior
ddp_kwargs = DistributedDataParallelKwargs(
bucket_cap_mb=25, # Gradient bucketing size
find_unused_parameters=False, # Find unused params (slower)
check_reduction=False, # Check gradient reduction
gradient_as_bucket_view=True, # Memory optimization
static_graph=False # Static computation graph
)
accelerator = Accelerator(
kwargs_handlers=[ddp_kwargs]
)
```
**Use case**: Optimize DDP performance for specific models
### 3. FP8RecipeKwargs (H100 FP8)
```python
from accelerate.utils import FP8RecipeKwargs
# Configure FP8 training (H100)
fp8_recipe = FP8RecipeKwargs(
backend="te", # TransformerEngine backend
margin=0, # Scaling margin
interval=1, # Scaling interval
fp8_format="HYBRID", # E4M3 + E5M2 hybrid
amax_history_len=1024, # AMAX history length
amax_compute_algo="max" # AMAX computation algorithm
)
accelerator = Accelerator(
mixed_precision='fp8',
kwargs_handlers=[fp8_recipe]
)
```
**Use case**: Ultra-fast training on H100 GPUs
## Custom DeepSpeed Configuration
### ZeRO-3 with CPU Offload
```python
from accelerate import Accelerator
from accelerate.utils import DeepSpeedPlugin
# Custom DeepSpeed config
ds_plugin = DeepSpeedPlugin(
zero_stage=3, # ZeRO-3
offload_optimizer_device="cpu", # CPU offload optimizer
offload_param_device="cpu", # CPU offload parameters
zero3_init_flag=True, # ZeRO-3 initialization
zero3_save_16bit_model=True, # Save FP16 weights
)
accelerator = Accelerator(
deepspeed_plugin=ds_plugin,
mixed_precision='bf16'
)
```
### ZeRO-2 with NVMe Offload
```python
ds_plugin = DeepSpeedPlugin(
zero_stage=2,
offload_optimizer_device="nvme", # NVMe offload
offload_param_device="nvme",
nvme_path="/local_nvme", # NVMe mount path
)
```
### Custom JSON Config
```python
import json
# Load custom DeepSpeed config
with open('deepspeed_config.json', 'r') as f:
ds_config = json.load(f)
ds_plugin = DeepSpeedPlugin(hf_ds_config=ds_config)
accelerator = Accelerator(deepspeed_plugin=ds_plugin)
```
**Example config** (`deepspeed_config.json`):
```json
{
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"gradient_accumulation_steps": "auto",
"gradient_clipping": 1.0,
"zero_optimization": {
"stage": 3,
"offload_optimizer": {
"device": "cpu",
"pin_memory": true
},
"offload_param": {
"device": "cpu",
"pin_memory": true
},
"overlap_comm": true,
"contiguous_gradients": true,
"sub_group_size": 1e9,
"reduce_bucket_size": 5e8,
"stage3_prefetch_bucket_size": 5e8,
"stage3_param_persistence_threshold": 1e6,
"stage3_max_live_parameters": 1e9,
"stage3_max_reuse_distance": 1e9,
"stage3_gather_16bit_weights_on_model_save": true
},
"bf16": {
"enabled": true
},
"steps_per_print": 100,
"wall_clock_breakdown": false
}
```
## Custom FSDP Configuration
### FSDP with Custom Auto-Wrap Policy
```python
from accelerate.utils import FullyShardedDataParallelPlugin
from torch.distributed.fsdp import BackwardPrefetch, ShardingStrategy
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
import functools
# Custom wrap policy (size-based)
wrap_policy = functools.partial(
size_based_auto_wrap_policy,
min_num_params=1e6 # Wrap layers with 1M+ params
)
fsdp_plugin = FullyShardedDataParallelPlugin(
sharding_strategy=ShardingStrategy.FULL_SHARD, # ZeRO-3 equivalent
backward_prefetch=BackwardPrefetch.BACKWARD_PRE, # Prefetch strategy
mixed_precision_policy=None, # Use Accelerator's mixed precision
auto_wrap_policy=wrap_policy, # Custom wrapping
cpu_offload=False,
ignored_modules=None, # Modules to not wrap
state_dict_type="FULL_STATE_DICT", # Save format
optim_state_dict_config=None,
limit_all_gathers=False,
use_orig_params=True, # Use original param shapes
)
accelerator = Accelerator(
fsdp_plugin=fsdp_plugin,
mixed_precision='bf16'
)
```
### FSDP with Transformer Auto-Wrap
```python
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers.models.gpt2.modeling_gpt2 import GPT2Block
# Wrap at transformer block level
wrap_policy = functools.partial(
transformer_auto_wrap_policy,
transformer_layer_cls={GPT2Block} # Wrap GPT2Block layers
)
fsdp_plugin = FullyShardedDataParallelPlugin(
auto_wrap_policy=wrap_policy
)
```
## Creating Custom Training Strategy
### Example: Custom Gradient Accumulation
```python
from accelerate import Accelerator
class CustomGradientAccumulation:
    def __init__(self, steps=4, adaptive=False, threshold=2.0):
        self.steps = steps
        self.adaptive = adaptive
        self.threshold = threshold  # loss value above which to sync early
        self.current_step = 0
    def should_sync(self, loss):
        """Decide whether to sync gradients."""
        self.current_step += 1
        # Adaptive: sync early on unusually high loss
        if self.adaptive and loss > self.threshold:
self.current_step = 0
return True
# Regular: sync every N steps
if self.current_step >= self.steps:
self.current_step = 0
return True
return False
# Usage
custom_accum = CustomGradientAccumulation(steps=8, adaptive=True)
accelerator = Accelerator()
for batch in dataloader:
outputs = model(**batch)
loss = outputs.loss
# Scale loss
loss = loss / custom_accum.steps
accelerator.backward(loss)
# Conditional sync
if custom_accum.should_sync(loss.item()):
optimizer.step()
optimizer.zero_grad()
```
### Example: Custom Mixed Precision
```python
import torch
class CustomMixedPrecision:
"""Custom mixed precision with dynamic loss scaling."""
def __init__(self, init_scale=2**16, scale_window=2000):
self.scaler = torch.cuda.amp.GradScaler(
init_scale=init_scale,
growth_interval=scale_window
)
self.scale_history = []
def scale_loss(self, loss):
"""Scale loss for backward."""
return self.scaler.scale(loss)
def unscale_and_clip(self, optimizer, max_norm=1.0):
"""Unscale gradients and clip."""
self.scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(
optimizer.param_groups[0]['params'],
max_norm
)
def step(self, optimizer):
"""Optimizer step with scaler update."""
scale_before = self.scaler.get_scale()
self.scaler.step(optimizer)
self.scaler.update()
scale_after = self.scaler.get_scale()
# Track scale changes
if scale_before != scale_after:
self.scale_history.append(scale_after)
# Usage
custom_mp = CustomMixedPrecision()
for batch in dataloader:
with torch.cuda.amp.autocast(dtype=torch.float16):
loss = model(**batch).loss
scaled_loss = custom_mp.scale_loss(loss)
scaled_loss.backward()
custom_mp.unscale_and_clip(optimizer, max_norm=1.0)
custom_mp.step(optimizer)
optimizer.zero_grad()
```
## Advanced: Custom Distributed Backend
### Custom AllReduce Strategy
```python
import torch.distributed as dist
class CustomAllReduce:
"""Custom all-reduce with compression."""
def __init__(self, compression_ratio=0.1):
self.compression_ratio = compression_ratio
    def compress_gradients(self, tensor):
        """Top-k gradient compression (keeps the signed values)."""
        k = max(1, int(tensor.numel() * self.compression_ratio))
        _, indices = torch.topk(tensor.abs().view(-1), k)
        values = tensor.view(-1)[indices]  # preserve gradient signs
        return values, indices
    def all_reduce_compressed(self, tensor):
        """All-reduce with compressed gradients (assumes ranks agree on indices)."""
# Compress
values, indices = self.compress_gradients(tensor)
# All-reduce compressed gradients
dist.all_reduce(values, op=dist.ReduceOp.SUM)
# Decompress
tensor_compressed = torch.zeros_like(tensor).view(-1)
tensor_compressed[indices] = values / dist.get_world_size()
return tensor_compressed.view_as(tensor)
# Usage in training loop
custom_ar = CustomAllReduce(compression_ratio=0.1)
for batch in dataloader:
loss = model(**batch).loss
loss.backward()
# Custom all-reduce
for param in model.parameters():
if param.grad is not None:
param.grad.data = custom_ar.all_reduce_compressed(param.grad.data)
optimizer.step()
optimizer.zero_grad()
```
## Plugin Best Practices
### 1. Validation in `__post_init__`
```python
@dataclass
class CustomPlugin:
learning_rate: float = 1e-3
warmup_steps: int = 1000
def __post_init__(self):
# Validate parameters
if self.learning_rate <= 0:
raise ValueError("learning_rate must be positive")
if self.warmup_steps < 0:
raise ValueError("warmup_steps must be non-negative")
# Compute derived values
self.min_lr = self.learning_rate * 0.1
```
### 2. Compatibility Checks
```python
@dataclass
class CustomPlugin:
feature_enabled: bool = True
def is_compatible(self, accelerator):
"""Check if plugin is compatible with accelerator config."""
if self.feature_enabled and accelerator.mixed_precision == 'fp8':
raise ValueError("Custom plugin not compatible with FP8")
return True
```
### 3. State Management
```python
@dataclass
class CustomPlugin:
counter: int = 0
history: list = None
def __post_init__(self):
if self.history is None:
self.history = []
def update_state(self, value):
"""Update plugin state during training."""
self.counter += 1
self.history.append(value)
```
## Resources
- Accelerate Plugins: https://huggingface.co/docs/accelerate/package_reference/kwargs
- DeepSpeed Config: https://www.deepspeed.ai/docs/config-json/
- FSDP Guide: https://pytorch.org/docs/stable/fsdp.html
- Custom Training Loops: https://huggingface.co/docs/accelerate/usage_guides/training_tpu

View file

@ -0,0 +1,489 @@
# Megatron Integration with Accelerate
## Overview
Accelerate supports Megatron-LM for massive model training with tensor parallelism and pipeline parallelism.
**Megatron capabilities**:
- **Tensor Parallelism (TP)**: Split layers across GPUs
- **Pipeline Parallelism (PP)**: Split model depth across GPUs
- **Data Parallelism (DP)**: Replicate model across GPU groups
- **Sequence Parallelism**: Split sequences for long contexts
## Setup
### Install Megatron-LM
```bash
# Clone Megatron-LM repository
git clone https://github.com/NVIDIA/Megatron-LM.git
cd Megatron-LM
pip install -e .
# Install Apex (NVIDIA optimizations)
git clone https://github.com/NVIDIA/apex
cd apex
pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation \
--config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./
```
### Accelerate Configuration
```bash
accelerate config
```
**Questions**:
```
In which compute environment are you running?
> This machine
Which type of machine are you using?
> Multi-GPU
How many different machines will you use?
> 1
Do you want to use DeepSpeed/FSDP?
> No
Do you want to use Megatron-LM?
> Yes
What is the Tensor Parallelism degree? [1-8]
> 2
Do you want to enable Sequence Parallelism?
> No
What is the Pipeline Parallelism degree? [1-8]
> 2
What is the Data Parallelism degree? [1-8]
> 2
Where to perform activation checkpointing? ['SELECTIVE', 'FULL', 'NONE']
> SELECTIVE
Where to perform activation partitioning? ['SEQUENTIAL', 'UNIFORM']
> SEQUENTIAL
```
**Generated config** (`~/.cache/huggingface/accelerate/default_config.yaml`):
```yaml
compute_environment: LOCAL_MACHINE
distributed_type: MEGATRON_LM
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
megatron_lm_config:
megatron_lm_gradient_clipping: 1.0
megatron_lm_learning_rate_decay_iters: 320000
megatron_lm_num_micro_batches: 1
megatron_lm_pp_degree: 2
megatron_lm_recompute_activations: true
megatron_lm_sequence_parallelism: false
megatron_lm_tp_degree: 2
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
```
## Parallelism Strategies
### Tensor Parallelism (TP)
**Splits each transformer layer across GPUs**:
```python
# Layer split across 2 GPUs
# GPU 0: First half of attention heads
# GPU 1: Second half of attention heads
# Each GPU computes partial outputs
# All-reduce combines results
```
**TP degree recommendations**:
- **TP=1**: No tensor parallelism (single GPU per layer)
- **TP=2**: 2 GPUs per layer (good for 7-13B models)
- **TP=4**: 4 GPUs per layer (good for 20-40B models)
- **TP=8**: 8 GPUs per layer (good for 70B+ models)
**Benefits**:
- Reduces memory per GPU
- All-reduce communication (fast)
**Drawbacks**:
- Requires fast inter-GPU bandwidth (NVLink)
- Communication overhead per layer
### Pipeline Parallelism (PP)
**Splits model depth across GPUs**:
```python
# 12-layer model, PP=4
# GPU 0: Layers 0-2
# GPU 1: Layers 3-5
# GPU 2: Layers 6-8
# GPU 3: Layers 9-11
```
**PP degree recommendations**:
- **PP=1**: No pipeline parallelism
- **PP=2**: 2 pipeline stages (good for 20-40B models)
- **PP=4**: 4 pipeline stages (good for 70B+ models)
- **PP=8**: 8 pipeline stages (good for 175B+ models)
**Benefits**:
- Linear memory reduction (4× PP = 4× less memory)
- Works across nodes (slower interconnect OK)
**Drawbacks**:
- Pipeline bubbles (idle time)
- Requires micro-batching
### Data Parallelism (DP)
**Replicates model across GPU groups**:
```python
# 8 GPUs, TP=2, PP=2, DP=2
# Group 0 (GPUs 0-3): Full model replica
# Group 1 (GPUs 4-7): Full model replica
```
**DP degree**:
- `DP = total_gpus / (TP × PP)`
- Example: 8 GPUs, TP=2, PP=2 → DP=2
**Benefits**:
- Increases throughput
- Scales batch size
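A quick sketch of the degree arithmetic above (plain Python; the variable names are illustrative, not Megatron API):
```python
# The three parallelism degrees must multiply out to the total GPU count.
total_gpus = 8
tp, pp = 2, 2                     # tensor / pipeline parallel degrees
dp = total_gpus // (tp * pp)      # data parallel degree -> 2 replicas
assert tp * pp * dp == total_gpus
print(f"TP={tp}, PP={pp}, DP={dp}")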
### Sequence Parallelism
**Splits long sequences across GPUs** (extends TP):
```python
# 8K sequence, TP=2, Sequence Parallel=True
# GPU 0: Tokens 0-4095
# GPU 1: Tokens 4096-8191
```
**Benefits**:
- Enables very long sequences (100K+ tokens)
- Reduces activation memory
**Requirements**:
- Must use with TP > 1
- RoPE/ALiBi position encodings work best
## Accelerate Code Example
### Basic Setup
```python
from accelerate import Accelerator
from accelerate.utils import MegatronLMPlugin
# Configure Megatron
megatron_plugin = MegatronLMPlugin(
tp_degree=2, # Tensor parallelism degree
pp_degree=2, # Pipeline parallelism degree
num_micro_batches=4, # Micro-batches for pipeline
gradient_clipping=1.0, # Gradient clipping value
sequence_parallelism=False, # Enable sequence parallelism
recompute_activations=True, # Activation checkpointing
use_distributed_optimizer=True, # Distributed optimizer
custom_prepare_model_function=None, # Custom model prep
)
# Initialize accelerator
accelerator = Accelerator(
mixed_precision='bf16',
megatron_lm_plugin=megatron_plugin
)
# Prepare model and optimizer
model, optimizer, train_dataloader = accelerator.prepare(
model, optimizer, train_dataloader
)
# Training loop (same as DDP!)
for batch in train_dataloader:
optimizer.zero_grad()
outputs = model(**batch)
loss = outputs.loss
accelerator.backward(loss)
optimizer.step()
```
### Full Training Script
```python
import torch
from accelerate import Accelerator
from accelerate.utils import MegatronLMPlugin
from transformers import GPT2Config, GPT2LMHeadModel
def main():
# Megatron configuration
megatron_plugin = MegatronLMPlugin(
tp_degree=2,
pp_degree=2,
num_micro_batches=4,
gradient_clipping=1.0,
)
accelerator = Accelerator(
mixed_precision='bf16',
gradient_accumulation_steps=8,
megatron_lm_plugin=megatron_plugin
)
# Model
config = GPT2Config(
n_layer=24,
n_head=16,
n_embd=1024,
)
model = GPT2LMHeadModel(config)
# Optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=6e-4)
    # Prepare (construct train_loader from your dataset before this point)
model, optimizer, train_loader = accelerator.prepare(
model, optimizer, train_loader
)
    # Training loop
    num_epochs = 3
    for epoch in range(num_epochs):
for batch in train_loader:
with accelerator.accumulate(model):
outputs = model(**batch)
loss = outputs.loss
accelerator.backward(loss)
optimizer.step()
optimizer.zero_grad()
# Save checkpoint
accelerator.wait_for_everyone()
accelerator.save_state(f'checkpoint-epoch-{epoch}')
if __name__ == '__main__':
main()
```
### Launch Command
```bash
# 8 GPUs, TP=2, PP=2, DP=2
accelerate launch --multi_gpu --num_processes 8 train.py
# Multi-node (2 nodes, 8 GPUs each)
# Node 0
accelerate launch --multi_gpu --num_processes 16 \
--num_machines 2 --machine_rank 0 \
--main_process_ip $MASTER_ADDR \
--main_process_port 29500 \
train.py
# Node 1
accelerate launch --multi_gpu --num_processes 16 \
--num_machines 2 --machine_rank 1 \
--main_process_ip $MASTER_ADDR \
--main_process_port 29500 \
train.py
```
## Activation Checkpointing
**Reduces memory by recomputing activations**:
```python
megatron_plugin = MegatronLMPlugin(
recompute_activations=True, # Enable checkpointing
checkpoint_num_layers=1, # Checkpoint every N layers
distribute_checkpointed_activations=True, # Distribute across TP
partition_activations=True, # Partition in PP
check_for_nan_in_loss_and_grad=True, # Stability check
)
```
**Strategies**:
- `SELECTIVE`: Checkpoint transformer blocks only
- `FULL`: Checkpoint all layers
- `NONE`: No checkpointing
**Memory savings**: 30-50% with 10-15% slowdown
## Distributed Optimizer
**Shards optimizer state across DP ranks**:
```python
megatron_plugin = MegatronLMPlugin(
use_distributed_optimizer=True, # Enable sharded optimizer
)
```
**Benefits**:
- Reduces optimizer memory by DP degree
- Example: DP=4 → 4× less optimizer memory per GPU
**Compatible with**:
- AdamW, Adam, SGD
- Mixed precision training
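Back-of-envelope arithmetic for the memory claim above (assumes AdamW with two fp32 moment tensors and a 7B-parameter model; purely illustrative):
```python
params = 7e9                                  # 7B-parameter model
adam_state_bytes = params * 2 * 4             # exp_avg + exp_avg_sq in fp32
dp = 4                                        # data-parallel degree
per_gpu_gb = adam_state_bytes / dp / 1e9      # state sharded across DP ranks
print(f"optimizer state per GPU: {per_gpu_gb:.0f} GB "
      f"(vs {adam_state_bytes / 1e9:.0f} GB unsharded)")   # 14 GB vs 56 GB
```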
## Performance Tuning
### Micro-Batch Size
```python
# Pipeline parallelism requires micro-batching
megatron_plugin = MegatronLMPlugin(
pp_degree=4,
num_micro_batches=16, # 16 micro-batches per pipeline
)
# Effective batch = num_micro_batches × micro_batch_size × DP
# Example: 16 × 2 × 4 = 128
```
**Recommendations**:
- More micro-batches → less pipeline bubble
- Typical: 4-16 micro-batches
### Sequence Length
```python
# For long sequences, enable sequence parallelism
megatron_plugin = MegatronLMPlugin(
tp_degree=4,
sequence_parallelism=True, # Required: TP > 1
)
# Enables sequences up to TP × normal limit
# Example: TP=4, 8K normal → 32K with sequence parallel
```
### GPU Topology
**NVLink required for TP**:
```bash
# Check NVLink topology
nvidia-smi topo -m
# Good topology (NVLink between all GPUs)
# GPU0 - GPU1: NV12 (fast)
# GPU0 - GPU2: NV12 (fast)
# Bad topology (PCIe only)
# GPU0 - GPU4: PHB (slow, avoid TP across these)
```
**Recommendations**:
- **TP**: Within same node (NVLink)
- **PP**: Across nodes (slower interconnect OK)
- **DP**: Any topology
## Model Size Guidelines
| Model Size | GPUs | TP | PP | DP | Micro-Batches |
|------------|------|----|----|----|--------------|
| 7B | 8 | 1 | 1 | 8 | 1 |
| 13B | 8 | 2 | 1 | 4 | 1 |
| 20B | 16 | 4 | 1 | 4 | 1 |
| 40B | 32 | 4 | 2 | 4 | 4 |
| 70B | 64 | 8 | 2 | 4 | 8 |
| 175B | 128 | 8 | 4 | 4 | 16 |
**Assumptions**: BF16, 2K sequence length, A100 80GB
## Checkpointing
### Save Checkpoint
```python
# Save full model state
accelerator.save_state('checkpoint-1000')
# Megatron saves separate files per rank
# checkpoint-1000/
# pytorch_model_tp_0_pp_0.bin
# pytorch_model_tp_0_pp_1.bin
# pytorch_model_tp_1_pp_0.bin
# pytorch_model_tp_1_pp_1.bin
# optimizer_tp_0_pp_0.bin
# ...
```
### Load Checkpoint
```python
# Resume training
accelerator.load_state('checkpoint-1000')
# Automatically loads correct shard per rank
```
### Convert to Standard PyTorch
```bash
# Merge Megatron checkpoint to single file
python merge_megatron_checkpoint.py \
--checkpoint-dir checkpoint-1000 \
--output pytorch_model.bin
```
## Common Issues
### Issue: OOM with Pipeline Parallelism
**Solution**: Increase micro-batches
```python
megatron_plugin = MegatronLMPlugin(
pp_degree=4,
num_micro_batches=16, # Increase from 4
)
```
### Issue: Slow Training
**Check 1**: Pipeline bubbles (PP too high)
```python
# Reduce PP, increase TP
tp_degree=4 # Increase
pp_degree=2 # Decrease
```
**Check 2**: Micro-batch size too small
```python
num_micro_batches=8 # Increase
```
### Issue: NVLink Not Detected
```bash
# Verify NVLink
nvidia-smi nvlink -s
# If no NVLink, avoid TP > 1
# Use PP or DP instead
```
## Resources
- Megatron-LM: https://github.com/NVIDIA/Megatron-LM
- Accelerate Megatron docs: https://huggingface.co/docs/accelerate/usage_guides/megatron_lm
- Paper: "Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism"
- NVIDIA Apex: https://github.com/NVIDIA/apex

View file

@ -0,0 +1,525 @@
# Accelerate Performance Tuning
## Profiling
### Basic Profiling
```python
from accelerate import Accelerator
import time
accelerator = Accelerator()
# Warmup
for _ in range(10):
batch = next(iter(dataloader))
outputs = model(**batch)
loss = outputs.loss
accelerator.backward(loss)
optimizer.step()
optimizer.zero_grad()
# Profile training loop
start = time.time()
total_batches = 100
for i, batch in enumerate(dataloader):
if i >= total_batches:
break
outputs = model(**batch)
loss = outputs.loss
accelerator.backward(loss)
optimizer.step()
optimizer.zero_grad()
accelerator.wait_for_everyone() # Sync all processes
elapsed = time.time() - start
# Metrics
batches_per_sec = total_batches / elapsed
samples_per_sec = (total_batches * batch_size * accelerator.num_processes) / elapsed
print(f"Throughput: {samples_per_sec:.2f} samples/sec")
print(f"Batches/sec: {batches_per_sec:.2f}")
```
### PyTorch Profiler Integration
```python
from torch.profiler import profile, ProfilerActivity
with profile(
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
record_shapes=True,
profile_memory=True,
with_stack=True
) as prof:
for i, batch in enumerate(dataloader):
if i >= 10: # Profile first 10 batches
break
outputs = model(**batch)
loss = outputs.loss
accelerator.backward(loss)
optimizer.step()
optimizer.zero_grad()
# Print profiling results
print(prof.key_averages().table(
sort_by="cuda_time_total", row_limit=20
))
# Export to Chrome tracing
prof.export_chrome_trace("trace.json")
# View at chrome://tracing
```
## Memory Optimization
### 1. Gradient Accumulation
**Problem**: Large batch size causes OOM
**Solution**: Accumulate gradients across micro-batches
```python
accelerator = Accelerator(gradient_accumulation_steps=8)
# Effective batch = batch_size × accumulation_steps × num_gpus
# Example: 4 × 8 × 8 = 256
for batch in dataloader:
with accelerator.accumulate(model): # Handles accumulation logic
outputs = model(**batch)
loss = outputs.loss
accelerator.backward(loss)
optimizer.step()
optimizer.zero_grad()
```
**Memory savings**: 8× less activation memory (with 8 accumulation steps)
### 2. Gradient Checkpointing
**Enable in model**:
```python
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained(
"gpt2",
use_cache=False # Required for gradient checkpointing
)
# Enable checkpointing
model.gradient_checkpointing_enable()
# Prepare with Accelerate
model = accelerator.prepare(model)
```
**Memory savings**: 30-50% with 10-15% slowdown
### 3. Mixed Precision
**BF16 (A100/H100)**:
```python
accelerator = Accelerator(mixed_precision='bf16')
# Automatic mixed precision
for batch in dataloader:
outputs = model(**batch) # Forward in BF16
loss = outputs.loss
accelerator.backward(loss) # Backward in FP32
optimizer.step()
```
**FP16 (V100, older GPUs)**:
```python
from accelerate.utils import GradScalerKwargs
scaler_kwargs = GradScalerKwargs(
init_scale=2.**16,
growth_interval=2000
)
accelerator = Accelerator(
mixed_precision='fp16',
kwargs_handlers=[scaler_kwargs]
)
```
**Memory savings**: 50% compared to FP32
### 4. CPU Offloading (DeepSpeed)
```python
from accelerate.utils import DeepSpeedPlugin
ds_plugin = DeepSpeedPlugin(
zero_stage=3,
offload_optimizer_device="cpu", # Offload optimizer to CPU
offload_param_device="cpu", # Offload parameters to CPU
)
accelerator = Accelerator(
deepspeed_plugin=ds_plugin,
mixed_precision='bf16'
)
```
**Memory savings**: 10-20× for optimizer state, 5-10× for parameters
**Trade-off**: 20-30% slower due to CPU-GPU transfers
### 5. Flash Attention
```python
# Install flash-attn
# pip install flash-attn
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained(
"gpt2",
attn_implementation="flash_attention_2" # Enable Flash Attention 2
)
model = accelerator.prepare(model)
```
**Memory savings**: 50% for attention, 2× faster
**Requirements**: Ampere or newer GPU (A100/H100, etc.); float16/bfloat16 inputs
## Communication Optimization
### 1. Gradient Bucketing (DDP)
```python
from accelerate.utils import DistributedDataParallelKwargs
ddp_kwargs = DistributedDataParallelKwargs(
bucket_cap_mb=25, # Bucket size for gradient reduction
gradient_as_bucket_view=True, # Reduce memory copies
static_graph=False # Set True if model doesn't change
)
accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])
```
**Recommended bucket sizes**:
- Small models (<1B): 25 MB
- Medium models (1-10B): 50-100 MB
- Large models (>10B): 100-200 MB
### 2. Find Unused Parameters
```python
# Only enable if model has unused parameters (slower!)
ddp_kwargs = DistributedDataParallelKwargs(
find_unused_parameters=True
)
```
**Use case**: Models with conditional branches (e.g., mixture of experts)
**Cost**: 10-20% slower
### 3. NCCL Tuning
```bash
# Set environment variables before launch
export NCCL_DEBUG=INFO # Debug info
export NCCL_IB_DISABLE=0 # Enable InfiniBand
export NCCL_SOCKET_IFNAME=eth0 # Network interface
export NCCL_P2P_LEVEL=NVL # Use NVLink
accelerate launch train.py
```
**NCCL_P2P_LEVEL options**:
- `NVL`: NVLink (fastest, within node)
- `PIX`: PCIe (fast, within node)
- `PHB`: PCIe host bridge (slow, cross-node)
## Data Loading Optimization
### 1. DataLoader Workers
```python
from torch.utils.data import DataLoader
train_loader = DataLoader(
dataset,
batch_size=32,
num_workers=4, # Parallel data loading
pin_memory=True, # Pin memory for faster GPU transfer
prefetch_factor=2, # Prefetch batches per worker
persistent_workers=True # Keep workers alive between epochs
)
train_loader = accelerator.prepare(train_loader)
```
**Recommendations**:
- `num_workers`: 2-4 per GPU (8 GPUs → 16-32 workers)
- `pin_memory`: Always True for GPU training
- `prefetch_factor`: 2-4 (higher for slow data loading)
### 2. Data Preprocessing
```python
from datasets import load_dataset
# Bad: Preprocess during training (slow)
dataset = load_dataset("openwebtext")
for batch in dataset:
tokens = tokenizer(batch['text']) # Slow!
...
# Good: Preprocess once, save
dataset = load_dataset("openwebtext")
tokenized = dataset.map(
lambda x: tokenizer(x['text']),
batched=True,
num_proc=8, # Parallel preprocessing
remove_columns=['text']
)
tokenized.save_to_disk("preprocessed_data")
# Load preprocessed
dataset = load_from_disk("preprocessed_data")
```
### 3. Faster Tokenization
```python
import os
# Enable Rust-based tokenizers (10× faster)
os.environ["TOKENIZERS_PARALLELISM"] = "true"
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(
"gpt2",
use_fast=True # Use fast Rust tokenizer
)
```
## Compilation (PyTorch 2.0+)
### Compile Model
```python
import torch
# Compile model for faster execution
model = torch.compile(
model,
mode="reduce-overhead", # Options: default, reduce-overhead, max-autotune
fullgraph=False, # Compile entire graph (stricter)
dynamic=True # Support dynamic shapes
)
model = accelerator.prepare(model)
```
**Speedup**: 10-50% depending on model
**Compilation modes**:
- `default`: Balanced (best for most cases)
- `reduce-overhead`: Min overhead (best for small batches)
- `max-autotune`: Max performance (slow compile, best for production)
### Compilation Best Practices
```python
# Bad: Compile after prepare (won't work)
model = accelerator.prepare(model)
model = torch.compile(model) # Error!
# Good: Compile before prepare
model = torch.compile(model)
model = accelerator.prepare(model)
# Training loop
for batch in dataloader:
# First iteration: slow (compilation)
# Subsequent iterations: fast (compiled)
outputs = model(**batch)
...
```
## Benchmarking Different Strategies
### Script Template
```python
import time
import torch
from accelerate import Accelerator
def benchmark_strategy(strategy_name, accelerator_kwargs):
"""Benchmark a specific training strategy."""
accelerator = Accelerator(**accelerator_kwargs)
# Setup
model = create_model()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
dataloader = create_dataloader()
model, optimizer, dataloader = accelerator.prepare(
model, optimizer, dataloader
)
# Warmup
for i, batch in enumerate(dataloader):
if i >= 10:
break
outputs = model(**batch)
loss = outputs.loss
accelerator.backward(loss)
optimizer.step()
optimizer.zero_grad()
# Benchmark
accelerator.wait_for_everyone()
torch.cuda.synchronize()
start = time.time()
num_batches = 100
for i, batch in enumerate(dataloader):
if i >= num_batches:
break
outputs = model(**batch)
loss = outputs.loss
accelerator.backward(loss)
optimizer.step()
optimizer.zero_grad()
accelerator.wait_for_everyone()
torch.cuda.synchronize()
elapsed = time.time() - start
# Metrics
throughput = (num_batches * batch_size * accelerator.num_processes) / elapsed
memory_used = torch.cuda.max_memory_allocated() / 1e9 # GB
if accelerator.is_main_process:
print(f"\n{strategy_name}:")
print(f" Throughput: {throughput:.2f} samples/sec")
print(f" Memory: {memory_used:.2f} GB")
print(f" Time: {elapsed:.2f} sec")
torch.cuda.reset_peak_memory_stats()
# Benchmark different strategies
strategies = [
("DDP + FP32", {}),
("DDP + BF16", {"mixed_precision": "bf16"}),
("DDP + BF16 + GradAccum", {"mixed_precision": "bf16", "gradient_accumulation_steps": 4}),
("FSDP", {"fsdp_plugin": fsdp_plugin}),
("DeepSpeed ZeRO-2", {"deepspeed_plugin": ds_plugin_stage2}),
("DeepSpeed ZeRO-3", {"deepspeed_plugin": ds_plugin_stage3}),
]
for name, kwargs in strategies:
benchmark_strategy(name, kwargs)
```
## Performance Checklist
**Before training**:
- [ ] Use BF16/FP16 mixed precision
- [ ] Enable gradient checkpointing (if OOM)
- [ ] Set appropriate `num_workers` (2-4 per GPU)
- [ ] Enable `pin_memory=True`
- [ ] Preprocess data once, not during training
- [ ] Compile model with `torch.compile` (PyTorch 2.0+)
**For large models**:
- [ ] Use FSDP or DeepSpeed ZeRO-3
- [ ] Enable CPU offloading (if still OOM)
- [ ] Use Flash Attention
- [ ] Increase gradient accumulation
**For multi-node**:
- [ ] Check network topology (InfiniBand > Ethernet)
- [ ] Tune NCCL settings
- [ ] Use larger bucket sizes for DDP
- [ ] Verify NVLink for tensor parallelism
**Profiling**:
- [ ] Profile first 10-100 batches
- [ ] Check GPU utilization (`nvidia-smi dmon`)
- [ ] Check data loading time (should be <5% of iteration)
- [ ] Identify communication bottlenecks
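A minimal sketch for the data-loading check in the list above (the `model`, `optimizer`, `dataloader`, and `accelerator` objects are assumed to exist; timings are rough wall-clock numbers, good enough to spot a loader bottleneck):
```python
import time

data_time, iter_time, n = 0.0, 0.0, 50
it = iter(dataloader)
for _ in range(n):
    t0 = time.time()
    batch = next(it)                      # time spent waiting on the DataLoader
    t1 = time.time()
    loss = model(**batch).loss            # HF-style model returning .loss
    accelerator.backward(loss)
    optimizer.step()
    optimizer.zero_grad()
    t2 = time.time()
    data_time += t1 - t0
    iter_time += t2 - t0
print(f"data loading fraction: {100 * data_time / iter_time:.1f}% (aim for <5%)")
```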
## Common Performance Issues
### Issue: Low GPU Utilization (<80%)
**Cause 1**: Data loading bottleneck
```python
# Solution: Increase workers and prefetch
num_workers=8
prefetch_factor=4
```
**Cause 2**: Small batch size
```python
# Solution: Increase batch size or use gradient accumulation
batch_size=32 # Increase
gradient_accumulation_steps=4 # Or accumulate
```
### Issue: High Memory Usage
**Solution 1**: Gradient checkpointing
```python
model.gradient_checkpointing_enable()
```
**Solution 2**: Reduce batch size, increase accumulation
```python
batch_size=8 # Reduce from 32
gradient_accumulation_steps=16 # Maintain effective batch
```
**Solution 3**: Use FSDP or DeepSpeed ZeRO-3
```python
accelerator = Accelerator(fsdp_plugin=fsdp_plugin)
```
### Issue: Slow Multi-GPU Training
**Cause**: Communication bottleneck
**Check 1**: Gradient bucket size
```python
ddp_kwargs = DistributedDataParallelKwargs(bucket_cap_mb=100)
```
**Check 2**: NCCL settings
```bash
export NCCL_DEBUG=INFO
# Check for "Using NVLS" (good) vs "Using PHB" (bad)
```
**Check 3**: Network bandwidth
```bash
# Test inter-GPU bandwidth
nvidia-smi nvlink -s
```
## Resources
- Accelerate Performance: https://huggingface.co/docs/accelerate/usage_guides/performance
- PyTorch Profiler: https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html
- NCCL Tuning: https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html
- Flash Attention: https://github.com/Dao-AILab/flash-attention

View file

@ -0,0 +1,161 @@
---
name: axolotl
description: Expert guidance for fine-tuning LLMs with Axolotl - YAML configs, 100+ models, LoRA/QLoRA, DPO/KTO/ORPO/GRPO, multimodal support
version: 1.0.0
author: Orchestra Research
license: MIT
dependencies: [axolotl, torch, transformers, datasets, peft, accelerate, deepspeed]
metadata:
hermes:
tags: [Fine-Tuning, Axolotl, LLM, LoRA, QLoRA, DPO, KTO, ORPO, GRPO, YAML, HuggingFace, DeepSpeed, Multimodal]
---
# Axolotl Skill
Comprehensive assistance with axolotl development, generated from official documentation.
## When to Use This Skill
This skill should be triggered when:
- Working with axolotl
- Asking about axolotl features or APIs
- Implementing axolotl solutions
- Debugging axolotl code
- Learning axolotl best practices
## Quick Reference
### Common Patterns
**Pattern 1:** To validate that acceptable data transfer speeds exist for your training job, running NCCL Tests can help pinpoint bottlenecks, for example:
```
./build/all_reduce_perf -b 8 -e 128M -f 2 -g 3
```
**Pattern 2:** Configure your model to use FSDP in the Axolotl yaml. For example:
```
fsdp_version: 2
fsdp_config:
offload_params: true
state_dict_type: FULL_STATE_DICT
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: LlamaDecoderLayer
reshard_after_forward: true
```
**Pattern 3:** The context_parallel_size should be a divisor of the total number of GPUs. For example:
```
context_parallel_size
```
**Pattern 4:** For example: - With 8 GPUs and no sequence parallelism: 8 different batches processed per step - With 8 GPUs and context_parallel_size=4: Only 2 different batches processed per step (each split across 4 GPUs) - If your per-GPU micro_batch_size is 2, the global batch size decreases from 16 to 4
```
context_parallel_size=4
```
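The batch-size arithmetic in Pattern 4, written out as a quick sanity check (plain Python, not Axolotl code):
```python
num_gpus = 8
micro_batch_size = 2

global_batch_no_cp = num_gpus * micro_batch_size                                 # 16
context_parallel_size = 4
global_batch_with_cp = (num_gpus // context_parallel_size) * micro_batch_size    # 4
print(global_batch_no_cp, global_batch_with_cp)
```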
**Pattern 5:** Setting save_compressed: true in your configuration enables saving models in a compressed format, which: - Reduces disk space usage by approximately 40% - Maintains compatibility with vLLM for accelerated inference - Maintains compatibility with llmcompressor for further optimization (example: quantization)
```
save_compressed: true
```
**Pattern 6:** Note: it is not necessary to place your integration in the `integrations` folder. It can be in any location, so long as it's installed as a package in your Python env. See this repo for an example: https://github.com/axolotl-ai-cloud/diff-transformer
```
integrations
```
**Pattern 7:** Handle both single-example and batched data: for a single example, `sample["input_ids"]` is a `list[int]`; for batched data, it is a `list[list[int]]`.
```
utils.trainer.drop_long_seq(sample, sequence_len=2048, min_sequence_len=2)
```
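A hedged sketch of the single-vs-batched handling Pattern 7 describes (the signature and filtering rule here are illustrative, not the actual Axolotl implementation):
```python
def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2):
    ids = sample["input_ids"]
    if ids and isinstance(ids[0], list):      # batched: list[list[int]] -> list of keep-flags
        return [min_sequence_len <= len(x) <= sequence_len for x in ids]
    return min_sequence_len <= len(ids) <= sequence_len   # single example -> bool
```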
### Example Code Patterns
**Example 1** (python):
```python
cli.cloud.modal_.ModalCloud(config, app=None)
```
**Example 2** (python):
```python
cli.cloud.modal_.run_cmd(cmd, run_folder, volumes=None)
```
**Example 3** (python):
```python
core.trainers.base.AxolotlTrainer(
*_args,
bench_data_collator=None,
eval_data_collator=None,
dataset_tags=None,
**kwargs,
)
```
**Example 4** (python):
```python
core.trainers.base.AxolotlTrainer.log(logs, start_time=None)
```
**Example 5** (python):
```python
prompt_strategies.input_output.RawInputOutputPrompter()
```
## Reference Files
This skill includes comprehensive documentation in `references/`:
- **api.md** - Api documentation
- **dataset-formats.md** - Dataset-Formats documentation
- **other.md** - Other documentation
Use `view` to read specific reference files when detailed information is needed.
## Working with This Skill
### For Beginners
Start with the getting_started or tutorials reference files for foundational concepts.
### For Specific Features
Use the appropriate category reference file (api, guides, etc.) for detailed information.
### For Code Examples
The quick reference section above contains common patterns extracted from the official docs.
## Resources
### references/
Organized documentation extracted from official sources. These files contain:
- Detailed explanations
- Code examples with language annotations
- Links to original documentation
- Table of contents for quick navigation
### scripts/
Add helper scripts here for common automation tasks.
### assets/
Add templates, boilerplate, or example projects here.
## Notes
- This skill was automatically generated from official documentation
- Reference files preserve the structure and examples from source docs
- Code examples include language detection for better syntax highlighting
- Quick reference patterns are extracted from common usage examples in the docs
## Updating
To refresh this skill with updated documentation:
1. Re-run the scraper with the same configuration
2. The skill will be rebuilt with the latest information

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,15 @@
# Axolotl Documentation Index
## Categories
### Api
**File:** `api.md`
**Pages:** 150
### Dataset-Formats
**File:** `dataset-formats.md`
**Pages:** 9
### Other
**File:** `other.md`
**Pages:** 26

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,370 @@
---
name: optimizing-attention-flash
description: Optimizes transformer attention with Flash Attention for 2-4x speedup and 10-20x memory reduction. Use when training/running transformers with long sequences (>512 tokens), encountering GPU memory issues with attention, or need faster inference. Supports PyTorch native SDPA, flash-attn library, H100 FP8, and sliding window attention.
version: 1.0.0
author: Orchestra Research
license: MIT
dependencies: [flash-attn, torch, transformers]
metadata:
hermes:
tags: [Optimization, Flash Attention, Attention Optimization, Memory Efficiency, Speed Optimization, Long Context, PyTorch, SDPA, H100, FP8, Transformers]
---
# Flash Attention - Fast Memory-Efficient Attention
## Quick start
Flash Attention provides 2-4x speedup and 10-20x memory reduction for transformer attention through IO-aware tiling and recomputation.
**PyTorch native (easiest, PyTorch 2.2+)**:
```python
import torch
import torch.nn.functional as F
q = torch.randn(2, 8, 512, 64, device='cuda', dtype=torch.float16) # [batch, heads, seq, dim]
k = torch.randn(2, 8, 512, 64, device='cuda', dtype=torch.float16)
v = torch.randn(2, 8, 512, 64, device='cuda', dtype=torch.float16)
# Automatically uses Flash Attention if available
out = F.scaled_dot_product_attention(q, k, v)
```
**flash-attn library (more features)**:
```bash
pip install flash-attn --no-build-isolation
```
```python
from flash_attn import flash_attn_func
# q, k, v: [batch, seqlen, nheads, headdim]
out = flash_attn_func(q, k, v, dropout_p=0.0, causal=True)
```
## Common workflows
### Workflow 1: Enable in existing PyTorch model
Copy this checklist:
```
Flash Attention Integration:
- [ ] Step 1: Check PyTorch version (≥2.2)
- [ ] Step 2: Enable Flash Attention backend
- [ ] Step 3: Verify speedup with profiling
- [ ] Step 4: Test accuracy matches baseline
```
**Step 1: Check PyTorch version**
```bash
python -c "import torch; print(torch.__version__)"
# Should be ≥2.2.0
```
If <2.2, upgrade:
```bash
pip install --upgrade torch
```
**Step 2: Enable Flash Attention backend**
Replace standard attention:
```python
# Before (standard attention)
attn_weights = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(d_k), dim=-1)
out = attn_weights @ v
# After (Flash Attention)
import torch.nn.functional as F
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
```
Force Flash Attention backend:
```python
with torch.backends.cuda.sdp_kernel(
enable_flash=True,
enable_math=False,
enable_mem_efficient=False
):
out = F.scaled_dot_product_attention(q, k, v)
```
**Step 3: Verify speedup with profiling**
```python
import torch.utils.benchmark as benchmark
def test_attention(use_flash):
q, k, v = [torch.randn(2, 8, 2048, 64, device='cuda', dtype=torch.float16) for _ in range(3)]
if use_flash:
with torch.backends.cuda.sdp_kernel(enable_flash=True):
return F.scaled_dot_product_attention(q, k, v)
else:
attn = (q @ k.transpose(-2, -1) / 8.0).softmax(dim=-1)
return attn @ v
# Benchmark
t_flash = benchmark.Timer(stmt='test_attention(True)', globals=globals())
t_standard = benchmark.Timer(stmt='test_attention(False)', globals=globals())
print(f"Flash: {t_flash.timeit(100).mean:.3f}s")
print(f"Standard: {t_standard.timeit(100).mean:.3f}s")
```
Expected: 2-4x speedup for sequences >512 tokens.
**Step 4: Test accuracy matches baseline**
```python
# Compare outputs
q, k, v = [torch.randn(1, 8, 512, 64, device='cuda', dtype=torch.float16) for _ in range(3)]
# Flash Attention
out_flash = F.scaled_dot_product_attention(q, k, v)
# Standard attention
attn_weights = torch.softmax(q @ k.transpose(-2, -1) / 8.0, dim=-1)
out_standard = attn_weights @ v
# Check difference
diff = (out_flash - out_standard).abs().max()
print(f"Max difference: {diff:.6f}")
# Should be <1e-3 for float16
```
### Workflow 2: Use flash-attn library for advanced features
For multi-query attention, sliding window, or H100 FP8.
Copy this checklist:
```
flash-attn Library Setup:
- [ ] Step 1: Install flash-attn library
- [ ] Step 2: Modify attention code
- [ ] Step 3: Enable advanced features
- [ ] Step 4: Benchmark performance
```
**Step 1: Install flash-attn library**
```bash
# NVIDIA GPUs (CUDA 12.0+)
pip install flash-attn --no-build-isolation
# Verify installation
python -c "from flash_attn import flash_attn_func; print('Success')"
```
**Step 2: Modify attention code**
```python
from flash_attn import flash_attn_func
# Input: [batch_size, seq_len, num_heads, head_dim]
# Transpose from [batch, heads, seq, dim] if needed
q = q.transpose(1, 2) # [batch, seq, heads, dim]
k = k.transpose(1, 2)
v = v.transpose(1, 2)
out = flash_attn_func(
q, k, v,
dropout_p=0.1,
causal=True, # For autoregressive models
window_size=(-1, -1), # No sliding window
softmax_scale=None # Auto-scale
)
out = out.transpose(1, 2) # Back to [batch, heads, seq, dim]
```
**Step 3: Enable advanced features**
Multi-query attention (shared K/V across heads):
```python
from flash_attn import flash_attn_func
# q: [batch, seq, num_q_heads, dim]
# k, v: [batch, seq, num_kv_heads, dim] # Fewer KV heads
out = flash_attn_func(q, k, v) # Automatically handles MQA
```
Sliding window attention (local attention):
```python
# Only attend to window of 256 tokens before/after
out = flash_attn_func(
q, k, v,
window_size=(256, 256), # (left, right) window
causal=True
)
```
**Step 4: Benchmark performance**
```python
import torch
from flash_attn import flash_attn_func
import time
q, k, v = [torch.randn(4, 4096, 32, 64, device='cuda', dtype=torch.float16) for _ in range(3)]
# Warmup
for _ in range(10):
_ = flash_attn_func(q, k, v)
# Benchmark
torch.cuda.synchronize()
start = time.time()
for _ in range(100):
out = flash_attn_func(q, k, v)
torch.cuda.synchronize()
end = time.time()
print(f"Time per iteration: {(end-start)/100*1000:.2f}ms")
print(f"Memory allocated: {torch.cuda.max_memory_allocated()/1e9:.2f}GB")
```
### Workflow 3: H100 FP8 optimization (FlashAttention-3)
For maximum performance on H100 GPUs.
```
FP8 Setup:
- [ ] Step 1: Verify H100 GPU available
- [ ] Step 2: Install flash-attn with FP8 support
- [ ] Step 3: Convert inputs to FP8
- [ ] Step 4: Run with FP8 attention
```
**Step 1: Verify H100 GPU**
```bash
nvidia-smi --query-gpu=name --format=csv
# Should show "H100" or "H800"
```
**Step 2: Install flash-attn with FP8 support**
```bash
pip install flash-attn --no-build-isolation
# FP8 support included for H100
```
**Step 3: Convert inputs to FP8**
```python
import torch
q = torch.randn(2, 4096, 32, 64, device='cuda', dtype=torch.float16)
k = torch.randn(2, 4096, 32, 64, device='cuda', dtype=torch.float16)
v = torch.randn(2, 4096, 32, 64, device='cuda', dtype=torch.float16)
# Convert to float8_e4m3 (FP8)
q_fp8 = q.to(torch.float8_e4m3fn)
k_fp8 = k.to(torch.float8_e4m3fn)
v_fp8 = v.to(torch.float8_e4m3fn)
```
**Step 4: Run with FP8 attention**
```python
from flash_attn import flash_attn_func
# FlashAttention-3 automatically uses FP8 kernels on H100
out = flash_attn_func(q_fp8, k_fp8, v_fp8)
# Result: ~1.2 PFLOPS, 1.5-2x faster than FP16
```
## When to use vs alternatives
**Use Flash Attention when:**
- Training transformers with sequences >512 tokens
- Running inference with long context (>2K tokens)
- GPU memory constrained (OOM with standard attention)
- Need 2-4x speedup without accuracy loss
- Using PyTorch 2.2+ or can install flash-attn
**Use alternatives instead:**
- **Standard attention**: Sequences <256 tokens (overhead not worth it)
- **xFormers**: Need more attention variants (not just speed)
- **Memory-efficient attention**: CPU inference (Flash Attention needs GPU)
## Common issues
**Issue: ImportError: cannot import flash_attn**
Install with no-build-isolation flag:
```bash
pip install flash-attn --no-build-isolation
```
Or install CUDA toolkit first:
```bash
conda install cuda -c nvidia
pip install flash-attn --no-build-isolation
```
**Issue: Slower than expected (no speedup)**
Flash Attention benefits increase with sequence length:
- <512 tokens: Minimal speedup (10-20%)
- 512-2K tokens: 2-3x speedup
- >2K tokens: 3-4x speedup
Check sequence length is sufficient.
**Issue: RuntimeError: CUDA error**
Verify GPU supports Flash Attention:
```python
import torch
print(torch.cuda.get_device_capability())
# Should be ≥(7, 5) for Turing+
```
Flash Attention requires:
- Ampere (A100, A10): ✅ Full support
- Turing (T4): ✅ Supported
- Volta (V100): ❌ Not supported
**Issue: Accuracy degradation**
Check dtype is float16 or bfloat16 (not float32):
```python
q = q.to(torch.float16) # Or torch.bfloat16
```
Flash Attention uses float16/bfloat16 for speed. Float32 not supported.
## Advanced topics
**Integration with HuggingFace Transformers**: See [references/transformers-integration.md](references/transformers-integration.md) for enabling Flash Attention in BERT, GPT, Llama models.
**Performance benchmarks**: See [references/benchmarks.md](references/benchmarks.md) for detailed speed and memory comparisons across GPUs and sequence lengths.
**Algorithm details**: See [references/algorithm.md](references/algorithm.md) for tiling strategy, recomputation, and IO complexity analysis.
**Advanced features**: See [references/advanced-features.md](references/advanced-features.md) for rotary embeddings, ALiBi, paged KV cache, and custom attention masks.
## Hardware requirements
- **GPU**: NVIDIA Ampere+ (A100, A10, A30) or AMD MI200+
- **VRAM**: Same as standard attention (Flash Attention doesn't increase memory)
- **CUDA**: 12.0+ (11.8 minimum)
- **PyTorch**: 2.2+ for native support
**Not supported**: V100 (Volta), CPU inference
## Resources
- Paper: "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness" (NeurIPS 2022)
- Paper: "FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning" (ICLR 2024)
- Blog: https://tridao.me/blog/2024/flash3/
- GitHub: https://github.com/Dao-AILab/flash-attention
- PyTorch docs: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html

View file

@ -0,0 +1,215 @@
# Performance Benchmarks
## Contents
- Speed comparisons across GPUs
- Memory usage analysis
- Scaling with sequence length
- Training vs inference performance
- Flash Attention versions comparison
## Speed comparisons across GPUs
### A100 80GB (Ampere)
**Forward pass time** (milliseconds, batch=8, heads=32, dim=64):
| Seq Length | Standard | Flash Attn 2 | Flash Attn 3 | Speedup (FA2) |
|------------|----------|--------------|--------------|---------------|
| 512 | 1.2 | 0.9 | N/A | 1.3x |
| 1024 | 3.8 | 1.4 | N/A | 2.7x |
| 2048 | 14.2 | 4.8 | N/A | 3.0x |
| 4096 | 55.1 | 17.3 | N/A | 3.2x |
| 8192 | 218.5 | 66.2 | N/A | 3.3x |
### H100 80GB (Hopper)
**Forward pass time** (milliseconds, same config):
| Seq Length | Standard | Flash Attn 2 | Flash Attn 3 (FP16) | Flash Attn 3 (FP8) | Best Speedup |
|------------|----------|--------------|---------------------|--------------------|--------------|
| 512 | 0.8 | 0.6 | 0.4 | 0.3 | 2.7x |
| 1024 | 2.6 | 1.0 | 0.6 | 0.4 | 6.5x |
| 2048 | 9.8 | 3.4 | 2.0 | 1.3 | 7.5x |
| 4096 | 38.2 | 12.5 | 7.2 | 4.8 | 8.0x |
| 8192 | 151.4 | 47.8 | 27.1 | 18.2 | 8.3x |
**Key insight**: Flash Attention 3 on H100 with FP8 achieves ~1.2 PFLOPS (75% of theoretical max).
### A10G 24GB (Ampere)
**Forward pass time** (milliseconds, batch=4):
| Seq Length | Standard | Flash Attn 2 | Speedup |
|------------|----------|--------------|---------|
| 512 | 2.1 | 1.6 | 1.3x |
| 1024 | 6.8 | 2.8 | 2.4x |
| 2048 | 25.9 | 9.4 | 2.8x |
| 4096 | 102.1 | 35.2 | 2.9x |
## Memory usage analysis
### GPU memory consumption (batch=8, heads=32, dim=64)
**Standard attention memory**:
| Seq Length | Attention Matrix | KV Cache | Total | Notes |
|------------|------------------|----------|-------|-------|
| 512 | 8 MB | 32 MB | 40 MB | Manageable |
| 2048 | 128 MB | 128 MB | 256 MB | Growing |
| 8192 | 2048 MB (2 GB) | 512 MB | 2.5 GB | Large |
| 32768 | 32768 MB (32 GB) | 2048 MB | 34 GB | OOM on 24GB GPUs |
**Flash Attention 2 memory**:
| Seq Length | Attention (on-chip) | KV Cache | Total | Reduction |
|------------|---------------------|----------|-------|-----------|
| 512 | 0 MB (recomputed) | 32 MB | 32 MB | 20% |
| 2048 | 0 MB | 128 MB | 128 MB | 50% |
| 8192 | 0 MB | 512 MB | 512 MB | 80% |
| 32768 | 0 MB | 2048 MB | 2 GB | 94% |
**Key insight**: Flash Attention doesn't materialize attention matrix, saving O(N²) memory.
### Memory scaling comparison
**Llama 2 7B model memory** (float16, batch=1):
| Context Length | Standard Attention | Flash Attention 2 | Can Fit 24GB GPU? |
|----------------|-------------------|-------------------|-------------------|
| 2K | 3.2 GB | 2.1 GB | Both: Yes |
| 4K | 5.8 GB | 2.8 GB | Both: Yes |
| 8K | 12.1 GB | 4.2 GB | Both: Yes |
| 16K | 26.3 GB (OOM) | 7.8 GB | Only Flash: Yes |
| 32K | OOM | 14.2 GB | Only Flash: Yes |
### Training memory (Llama 2 7B, batch=4)
| Context | Standard (GB) | Flash Attn (GB) | Reduction |
|---------|---------------|-----------------|-----------|
| 2K | 18.2 | 12.4 | 32% |
| 4K | 34.8 | 16.8 | 52% |
| 8K | OOM (>40GB) | 26.2 | Fits! |
## Scaling with sequence length
### Computational complexity
**Standard attention**:
- Time: O(N² × d)
- Memory: O(N² + N × d)
**Flash Attention**:
- Time: O(N² × d) (same, but with better constants)
- Memory: O(N × d) (linear!)
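To make the asymptotics concrete, a rough byte count under the benchmark configuration above (fp16, batch=8, heads=32, head_dim=64; illustrative only):
```python
batch, heads, seq, dim = 8, 32, 4096, 64
bytes_fp16 = 2

standard_attn_matrix_gb = batch * heads * seq * seq * bytes_fp16 / 1e9   # O(N^2): ~8.6 GB
flash_working_set_gb = batch * heads * seq * dim * bytes_fp16 / 1e9      # O(N*d): ~0.13 GB
print(f"standard: {standard_attn_matrix_gb:.2f} GB, flash: {flash_working_set_gb:.3f} GB")
```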
### Empirical scaling (A100, batch=1, heads=32, dim=64)
**Time per token (milliseconds)**:
| Sequence | 512 | 1K | 2K | 4K | 8K | 16K |
|----------|-----|-----|-----|-----|-----|------|
| Standard | 0.15 | 0.37 | 1.11 | 3.44 | 13.4 | 52.8 |
| Flash Attn 2 | 0.11 | 0.14 | 0.24 | 0.43 | 0.83 | 1.64 |
| Speedup | 1.4x | 2.6x | 4.6x | 8.0x | 16.1x | 32.2x |
**Observation**: The speedup roughly doubles each time the sequence length doubles!
### Memory per token (MB)
| Sequence | 512 | 1K | 2K | 4K | 8K | 16K |
|----------|-----|-----|-----|-----|-----|------|
| Standard | 0.08 | 0.13 | 0.25 | 0.64 | 2.05 | 8.13 |
| Flash Attn 2 | 0.06 | 0.06 | 0.06 | 0.06 | 0.06 | 0.06 |
**Observation**: Flash Attention memory per token is constant!
## Training vs inference performance
### Training (forward + backward, Llama 2 7B, A100)
| Batch × Seq | Standard (samples/sec) | Flash Attn (samples/sec) | Speedup |
|-------------|------------------------|--------------------------|---------|
| 4 × 2K | 1.2 | 3.1 | 2.6x |
| 8 × 2K | 2.1 | 5.8 | 2.8x |
| 4 × 4K | 0.4 | 1.3 | 3.3x |
| 8 × 4K | OOM | 2.4 | Enabled |
| 2 × 8K | 0.1 | 0.4 | 4.0x |
### Inference (generation, Llama 2 7B, A100)
| Context Length | Standard (tokens/sec) | Flash Attn (tokens/sec) | Speedup |
|----------------|----------------------|-------------------------|---------|
| 512 | 48 | 52 | 1.1x |
| 2K | 42 | 62 | 1.5x |
| 4K | 31 | 58 | 1.9x |
| 8K | 18 | 51 | 2.8x |
| 16K | OOM | 42 | Enabled |
**Note**: Inference speedup less dramatic than training because generation is memory-bound (KV cache accesses).
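Rough KV-cache arithmetic behind the memory-bound note above (Llama-2-7B-style shape: 32 layers, 32 heads, head_dim 128, fp16; illustrative):
```python
layers, heads, head_dim, bytes_fp16 = 32, 32, 128, 2
context = 8192
kv_cache_gb = 2 * layers * heads * head_dim * context * bytes_fp16 / 1e9  # K and V tensors
print(f"KV cache per sequence at {context} tokens: {kv_cache_gb:.1f} GB")  # ~4.3 GB
```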
## Flash Attention versions comparison
### Flash Attention 1 vs 2 vs 3 (H100, seq=4096, batch=8)
| Metric | FA1 | FA2 | FA3 (FP16) | FA3 (FP8) |
|--------|-----|-----|------------|-----------|
| Forward time (ms) | 28.4 | 12.5 | 7.2 | 4.8 |
| Memory (GB) | 4.8 | 4.2 | 4.2 | 2.8 |
| TFLOPS | 180 | 420 | 740 | 1150 |
| GPU util % | 35% | 55% | 75% | 82% |
**Key improvements**:
- FA2: 2.3x faster than FA1 (better parallelism)
- FA3 (FP16): 1.7x faster than FA2 (H100 async optimizations)
- FA3 (FP8): 2.6x faster than FA2 (low precision)
### Features by version
| Feature | FA1 | FA2 | FA3 |
|---------|-----|-----|-----|
| Basic attention | ✅ | ✅ | ✅ |
| Causal masking | ✅ | ✅ | ✅ |
| Multi-query attention | ❌ | ✅ | ✅ |
| Sliding window | ❌ | ✅ | ✅ |
| Paged KV cache | ❌ | ✅ | ✅ |
| FP8 support | ❌ | ❌ | ✅ (H100 only) |
| Work partitioning | Basic | Advanced | Optimal |
## Real-world model benchmarks
### Llama 2 models (A100 80GB, batch=4, seq=2048)
| Model | Params | Standard (samples/sec) | Flash Attn (samples/sec) | Speedup |
|-------|--------|------------------------|--------------------------|---------|
| Llama 2 7B | 7B | 1.2 | 3.1 | 2.6x |
| Llama 2 13B | 13B | 0.6 | 1.7 | 2.8x |
| Llama 2 70B | 70B | 0.12 | 0.34 | 2.8x |
### GPT-style models (seq=1024)
| Model | Standard (tokens/sec) | Flash Attn (tokens/sec) | Speedup |
|-------|----------------------|-------------------------|---------|
| GPT-2 (124M) | 520 | 680 | 1.3x |
| GPT-J (6B) | 42 | 98 | 2.3x |
| GPT-NeoX (20B) | 8 | 22 | 2.75x |
## Recommendations by use case
**Training large models (>7B parameters)**:
- Use Flash Attention 2 on A100
- Use Flash Attention 3 FP8 on H100 for maximum speed
- Expected: 2.5-3x speedup
**Long context inference (>4K tokens)**:
- Flash Attention essential (enables contexts standard attention can't handle)
- Expected: 2-4x speedup, 5-10x memory reduction
**Short sequences (<512 tokens)**:
- Flash Attention provides 1.2-1.5x speedup
- Minimal memory benefit
- Still worth enabling (no downside)
**Multi-user serving**:
- Flash Attention reduces per-request memory
- Allows higher concurrent batch sizes
- Can serve 2-3x more users on same hardware

View file

@ -0,0 +1,293 @@
# HuggingFace Transformers Integration
## Contents
- Enabling Flash Attention in Transformers
- Supported model architectures
- Configuration examples
- Performance comparisons
- Troubleshooting model-specific issues
## Enabling Flash Attention in Transformers
HuggingFace Transformers (v4.36+) supports Flash Attention 2 natively.
**Simple enable for any supported model**:
```python
from transformers import AutoModel
model = AutoModel.from_pretrained(
"meta-llama/Llama-2-7b-hf",
attn_implementation="flash_attention_2",
torch_dtype=torch.float16,
device_map="auto"
)
```
**Install requirements**:
```bash
pip install transformers>=4.36
pip install flash-attn --no-build-isolation
```
## Supported model architectures
As of Transformers 4.40:
**Fully supported**:
- Llama / Llama 2 / Llama 3
- Mistral / Mixtral
- Falcon
- GPT-NeoX
- Phi / Phi-2 / Phi-3
- Qwen / Qwen2
- Gemma
- Starcoder2
- GPT-J
- OPT
- BLOOM
**Partially supported** (encoder-decoder):
- BART
- T5 / Flan-T5
- Whisper
**Check support**:
```python
from transformers import AutoConfig
config = AutoConfig.from_pretrained("model-name")
print(config._attn_implementation_internal)
# 'flash_attention_2' if supported
```
## Configuration examples
### Llama 2 with Flash Attention
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
model_id = "meta-llama/Llama-2-7b-hf"
model = AutoModelForCausalLM.from_pretrained(
model_id,
attn_implementation="flash_attention_2",
torch_dtype=torch.float16,
device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Generate
inputs = tokenizer("Once upon a time", return_tensors="pt").to("cuda")
outputs = model.generate(**inputs, max_length=100)
print(tokenizer.decode(outputs[0]))
```
### Mistral with Flash Attention for long context
```python
from transformers import AutoModelForCausalLM
import torch
model = AutoModelForCausalLM.from_pretrained(
"mistralai/Mistral-7B-v0.1",
attn_implementation="flash_attention_2",
torch_dtype=torch.bfloat16, # Better for long context
device_map="auto",
max_position_embeddings=32768 # Extended context
)
# Process long document (32K tokens)
long_text = "..." * 10000
inputs = tokenizer(long_text, return_tensors="pt", truncation=False).to("cuda")
outputs = model.generate(**inputs, max_new_tokens=512)
```
### Fine-tuning with Flash Attention
```python
from transformers import Trainer, TrainingArguments
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-2-7b-hf",
attn_implementation="flash_attention_2",
torch_dtype=torch.float16
)
training_args = TrainingArguments(
output_dir="./results",
per_device_train_batch_size=4,
gradient_accumulation_steps=4,
num_train_epochs=3,
fp16=True, # Must match model dtype
optim="adamw_torch_fused" # Fast optimizer
)
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset
)
trainer.train()
```
### Multi-GPU training
```python
from transformers import AutoModelForCausalLM
import torch
# Model parallelism with Flash Attention
model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-2-13b-hf",
attn_implementation="flash_attention_2",
torch_dtype=torch.float16,
device_map="auto", # Automatic multi-GPU placement
max_memory={0: "20GB", 1: "20GB"} # Limit per GPU
)
```
## Performance comparisons
### Memory usage (Llama 2 7B, batch=1)
| Sequence Length | Standard Attention | Flash Attention 2 | Reduction |
|-----------------|-------------------|-------------------|-----------|
| 512 | 1.2 GB | 0.9 GB | 25% |
| 2048 | 3.8 GB | 1.4 GB | 63% |
| 8192 | 14.2 GB | 3.2 GB | 77% |
| 32768 | OOM (>24GB) | 10.8 GB | Fits! |
### Speed (tokens/sec, A100 80GB)
| Model | Standard | Flash Attn 2 | Speedup |
|-------|----------|--------------|---------|
| Llama 2 7B (seq=2048) | 42 | 118 | 2.8x |
| Llama 2 13B (seq=4096) | 18 | 52 | 2.9x |
| Llama 2 70B (seq=2048) | 4 | 11 | 2.75x |
### Training throughput (samples/sec)
| Model | Batch Size | Standard | Flash Attn 2 | Speedup |
|-------|------------|----------|--------------|---------|
| Llama 2 7B | 4 | 1.2 | 3.1 | 2.6x |
| Llama 2 7B | 8 | 2.1 | 5.8 | 2.8x |
| Llama 2 13B | 2 | 0.6 | 1.7 | 2.8x |
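Throughput depends on the GPU, driver, and transformers version, so treat the numbers above as indicative. A minimal sketch, assuming the model fits on a single GPU in fp16, for reproducing the tokens/sec comparison (the model ID and prompt are placeholder choices):
```python
import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def generation_tokens_per_sec(model_id: str, attn: str, new_tokens: int = 256) -> float:
    """Greedy-decoding throughput (tokens/sec) for the given attention backend."""
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        attn_implementation=attn,
        torch_dtype=torch.float16,
        device_map="auto"
    )
    inputs = tokenizer("Once upon a time", return_tensors="pt").to("cuda")
    model.generate(**inputs, max_new_tokens=16)  # warm-up pass to initialize kernels
    torch.cuda.synchronize()
    start = time.perf_counter()
    model.generate(**inputs, max_new_tokens=new_tokens, min_new_tokens=new_tokens,
                   do_sample=False)
    torch.cuda.synchronize()
    elapsed = time.perf_counter() - start
    del model
    torch.cuda.empty_cache()
    return new_tokens / elapsed

for attn in ("eager", "flash_attention_2"):
    print(attn, f"{generation_tokens_per_sec('meta-llama/Llama-2-7b-hf', attn):.1f} tok/s")
```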
## Troubleshooting model-specific issues
### Issue: Model doesn't support Flash Attention
Check support list above. If not supported, use PyTorch SDPA as fallback:
```python
model = AutoModelForCausalLM.from_pretrained(
    "model-name",
    attn_implementation="sdpa",  # PyTorch-native SDPA, still faster than eager attention
    torch_dtype=torch.float16
)
```
### Issue: CUDA out of memory during loading
Reduce memory footprint:
```python
model = AutoModelForCausalLM.from_pretrained(
"model-name",
attn_implementation="flash_attention_2",
torch_dtype=torch.float16,
device_map="auto",
max_memory={0: "18GB"}, # Reserve memory for KV cache
low_cpu_mem_usage=True
)
```
### Issue: Slower inference than expected
Ensure dtype matches:
```python
# Flash Attention only runs in float16/bfloat16, so the model weights must be cast
print(model.dtype)  # should report torch.float16 or torch.bfloat16
model = model.to(torch.float16)

# Token IDs stay integer; only floating-point inputs (e.g. pixel_values) need casting
inputs = tokenizer(..., return_tensors="pt").to("cuda")
inputs = {k: v.to(torch.float16) if v.is_floating_point() else v
          for k, v in inputs.items()}
```
### Issue: Different outputs vs standard attention
Flash Attention computes the same attention mathematically, but its tiled accumulation order differs from standard attention, so small floating-point differences (<1e-3 in float16) are expected:
```python
# Compare outputs against the standard (eager) implementation
model_standard = AutoModelForCausalLM.from_pretrained(
    "model-name",
    attn_implementation="eager",
    torch_dtype=torch.float16,
    device_map="auto"
)
model_flash = AutoModelForCausalLM.from_pretrained(
    "model-name",
    attn_implementation="flash_attention_2",
    torch_dtype=torch.float16,
    device_map="auto"
)
inputs = tokenizer("Test", return_tensors="pt").to("cuda")
with torch.no_grad():
    out_standard = model_standard(**inputs).logits
    out_flash = model_flash(**inputs).logits
diff = (out_standard - out_flash).abs().max()
print(f"Max diff: {diff:.6f}")  # typically ~1e-4 to 1e-3
```
### Issue: ImportError during model loading
Install flash-attn:
```bash
pip install flash-attn --no-build-isolation
```
Or disable Flash Attention:
```python
model = AutoModelForCausalLM.from_pretrained(
"model-name",
attn_implementation="eager", # Standard PyTorch
torch_dtype=torch.float16
)
```
## Best practices
1. **Always use float16/bfloat16** with Flash Attention (not float32)
2. **Set device_map="auto"** for automatic memory management
3. **Use bfloat16 for long context** (better numerical stability)
4. **Enable gradient checkpointing** for training large models
5. **Monitor memory** with `torch.cuda.max_memory_allocated()` (see the sketch after the example below)
**Example with all best practices**:
```python
import torch
from transformers import AutoModelForCausalLM, TrainingArguments
model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-2-7b-hf",
attn_implementation="flash_attention_2",
torch_dtype=torch.bfloat16, # Better for training
device_map="auto",
low_cpu_mem_usage=True
)
# Enable gradient checkpointing for memory
model.gradient_checkpointing_enable()
# Training with optimizations
training_args = TrainingArguments(
output_dir="./results",
per_device_train_batch_size=8,
gradient_accumulation_steps=2,
bf16=True, # Match model dtype
optim="adamw_torch_fused",
gradient_checkpointing=True
)
```
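For best practice 5, a short sketch of how peak memory can be tracked around a training or generation step; the 90% threshold is an arbitrary example:
```python
import torch

torch.cuda.reset_peak_memory_stats()

# ... run one training step or model.generate(...) here ...

peak_gb = torch.cuda.max_memory_allocated() / 1e9
total_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
print(f"Peak GPU memory: {peak_gb:.2f} / {total_gb:.0f} GB")
if peak_gb > 0.9 * total_gb:
    print("Close to OOM: reduce batch size, sequence length, or enable gradient checkpointing")
```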
