Part 3 — Fine-Tuning¶
Overview
Fine-tuning adapts a pre-trained model to your task. This article walks through the full spectrum: full fine-tuning, parameter-efficient methods (LoRA, QLoRA), supervised instruction tuning with trl, and the principles behind RLHF.
1. Taxonomy of Fine-Tuning Approaches¶
Full Fine-Tuning
├── All parameters updated
└── Most expensive, best ceiling
Parameter-Efficient Fine-Tuning (PEFT)
├── LoRA — inject low-rank matrices
├── QLoRA — quantised base + LoRA
├── Prefix Tuning — prepend learnable tokens
└── Adapters — small bottleneck layers
Instruction Fine-Tuning (SFT)
└── Supervised training on (prompt, response) pairs
Alignment
├── RLHF — reward model + PPO
└── DPO — direct preference optimisation
2. Full Fine-Tuning¶
All model weights are updated. Use when you have enough compute and enough data.
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
)
from datasets import load_dataset

model = AutoModelForSequenceClassification.from_pretrained(
    "distilbert-base-uncased",
    num_labels=2,
)
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

# Tokenize up front; Trainer expects input_ids, not raw text
ds = load_dataset("imdb")
ds = ds.map(
    lambda batch: tokenizer(batch["text"], truncation=True),
    batched=True,
)

args = TrainingArguments(
    output_dir="checkpoints/distilbert-imdb",
    num_train_epochs=3,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=64,
    learning_rate=2e-5,
    weight_decay=0.01,
    eval_strategy="epoch",   # evaluation_strategy on transformers < 4.41
    save_strategy="epoch",
    load_best_model_at_end=True,
    bf16=True,
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=ds["train"],
    eval_dataset=ds["test"],
    processing_class=tokenizer,   # tokenizer= on transformers < 4.46
)
trainer.train()
Catastrophic forgetting
Full fine-tuning on a small dataset can destroy the model's general capabilities. Use a low LR (2e-5 to 5e-5), a short schedule, and early stopping.
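The patience logic behind early stopping is simple enough to sketch in plain Python (in transformers it is provided by EarlyStoppingCallback, used together with load_best_model_at_end=True and per-epoch evaluation):

```python
def should_stop(eval_losses: list[float], patience: int = 2) -> bool:
    """True once eval loss has failed to improve for `patience` evaluations."""
    if len(eval_losses) <= patience:
        return False
    best_so_far = min(eval_losses[:-patience])
    return all(loss >= best_so_far for loss in eval_losses[-patience:])

print(should_stop([0.61, 0.48, 0.45, 0.46, 0.47]))  # True: two evals without improvement
print(should_stop([0.61, 0.48, 0.45]))              # False: still improving
```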
3. LoRA — Low-Rank Adaptation¶
LoRA freezes the original weights and injects a pair of small trainable matrices into each target layer. The update is ΔW = B·A, where B ∈ ℝ^(d×r) and A ∈ ℝ^(r×k), with rank r ≪ min(d, k).
For large models this can cut the number of trainable parameters by up to ~10,000× (the reduction reported for GPT-3 175B in the LoRA paper).
# pip install peft
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, get_peft_model, TaskType

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
base_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B",
    torch_dtype="auto",
    device_map="auto",
)

lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=16,                                 # rank — higher = more capacity, more params
    lora_alpha=32,                        # scaling factor (alpha/r is the effective scale)
    target_modules=["q_proj", "v_proj"],  # which layers to inject LoRA into
    lora_dropout=0.05,
    bias="none",
)
model = get_peft_model(base_model, lora_config)
model.print_trainable_parameters()
# trainable params: 1,310,720 || all params: 1,236,137,984 || trainable%: 0.11%
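That parameter count is easy to sanity-check by hand. A minimal sketch of the arithmetic for a single d×k projection (hypothetical sizes, not tied to any particular model):

```python
def lora_param_count(d: int, k: int, r: int) -> tuple[int, int]:
    """(full, lora) trainable parameter counts for one d×k weight matrix."""
    full = d * k        # updating W directly
    lora = r * (d + k)  # B is d×r, A is r×k; W itself stays frozen
    return full, lora

full, lora = lora_param_count(d=2048, k=2048, r=16)
print(full, lora, full // lora)  # 4194304 65536 64
```

The ratio grows with layer size, which is why the savings are most dramatic on the largest models.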
Choosing LoRA Targets¶
| Model Family | Recommended Targets |
|---|---|
| LLaMA / Mistral | q_proj, v_proj (minimum), add k_proj, o_proj, gate_proj for more capacity |
| GPT-2 | c_attn, c_proj |
| BERT/RoBERTa | query, value |
| Falcon | query_key_value, dense |
Saving & Merging LoRA Weights¶
# Save only the LoRA adapter (a few MB)
model.save_pretrained("adapters/llama-lora")
tokenizer.save_pretrained("adapters/llama-lora")
# Merge LoRA into base weights for faster inference
merged = model.merge_and_unload()
merged.save_pretrained("models/llama-merged")
4. QLoRA — Quantised LoRA¶
QLoRA loads the base model in 4-bit (NF4) quantisation, then trains LoRA adapters in bf16. This lets you fine-tune a 7B model on a single 16 GB GPU.
# pip install bitsandbytes peft transformers accelerate
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
# 1. Load base model in 4-bit
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4", # NormalFloat4 — best for normally distributed weights
bnb_4bit_compute_dtype=torch.bfloat16,
bnb_4bit_use_double_quant=True, # quantise the quantisation constants too (saves ~0.4 GB)
)
base_model = AutoModelForCausalLM.from_pretrained(
"mistralai/Mistral-7B-v0.1",
quantization_config=bnb_config,
device_map="auto",
)
# 2. Prepare for k-bit training (enables gradient checkpointing, casts norms to fp32)
base_model = prepare_model_for_kbit_training(base_model)
# 3. Attach LoRA
lora_config = LoraConfig(
r=64,
lora_alpha=16,
target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
"gate_proj", "up_proj", "down_proj"],
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM",
)
model = get_peft_model(base_model, lora_config)
model.print_trainable_parameters()
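A rough back-of-envelope for why 4-bit loading matters, counting weight storage only (activations, LoRA parameters, and optimizer state add more on top, and exact figures vary by model):

```python
def weight_memory_gb(n_params: float, bits_per_param: float) -> float:
    """Weight-only footprint in GiB; ignores activations, gradients, optimizer state."""
    return n_params * bits_per_param / 8 / 1024**3

params_7b = 7e9
print(round(weight_memory_gb(params_7b, 16), 1))  # bf16 weights: ~13.0 GiB
print(round(weight_memory_gb(params_7b, 4), 1))   # NF4 weights:  ~3.3 GiB
```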
5. Supervised Fine-Tuning (SFT) with trl¶
SFTTrainer wraps Trainer with instruction-tuning convenience: chat template formatting, completion-only loss masking, and packing.
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig
import torch
# Load tokenizer and model (QLoRA setup)
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
)
model = AutoModelForCausalLM.from_pretrained(
"mistralai/Mistral-7B-v0.1",
quantization_config=bnb_config,
device_map="auto",
)
lora_config = LoraConfig(
r=16, lora_alpha=32,
target_modules=["q_proj", "v_proj"],
task_type="CAUSAL_LM",
)
# Dataset — must have a "text" column (or configure formatting_func)
ds = load_dataset("json", data_files="data/instructions.jsonl", split="train")
def format_prompt(example):
    return {"text": f"### Instruction:\n{example['instruction']}\n\n### Response:\n{example['output']}"}

ds = ds.map(format_prompt)
sft_config = SFTConfig(
output_dir="checkpoints/mistral-sft",
num_train_epochs=3,
per_device_train_batch_size=2,
gradient_accumulation_steps=16,
optim="paged_adamw_8bit", # memory-efficient optimizer for QLoRA
learning_rate=2e-4,
fp16=False,
bf16=True,
max_grad_norm=0.3,
warmup_ratio=0.03,
lr_scheduler_type="cosine",
logging_steps=25,
save_steps=500,
max_seq_length=2048,
packing=False, # set True to pack short examples for efficiency
report_to="wandb",
)
trainer = SFTTrainer(
model=model,
train_dataset=ds,
peft_config=lora_config,
processing_class=tokenizer,
args=sft_config,
)
trainer.train()
trainer.save_model()
6. Completion-Only Loss Masking¶
By default the loss is computed over the full sequence including the instruction. Masking user tokens means the model only learns to generate the response.
from trl import DataCollatorForCompletionOnlyLM
# The response template is the string that separates instruction from response
response_template = "### Response:"
collator = DataCollatorForCompletionOnlyLM(
response_template=response_template,
tokenizer=tokenizer,
)
trainer = SFTTrainer(
...
data_collator=collator,
)
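What the collator does can be illustrated without trl: every label up to and including the response template is replaced with -100, the index PyTorch's cross-entropy loss ignores. A simplified sketch over toy token IDs (the real collator matches the template's token IDs and also handles batching and padding):

```python
IGNORE_INDEX = -100

def mask_prompt(input_ids: list[int], template: list[int]) -> list[int]:
    """Copy input_ids into labels, masking everything up to and including
    the first occurrence of the response-template token sequence."""
    labels = list(input_ids)
    for start in range(len(input_ids) - len(template) + 1):
        if input_ids[start:start + len(template)] == template:
            end = start + len(template)
            labels[:end] = [IGNORE_INDEX] * end
            break
    return labels

# Toy example: tokens [7, 8] stand in for "### Response:"
print(mask_prompt([1, 2, 3, 7, 8, 4, 5], template=[7, 8]))
# [-100, -100, -100, -100, -100, 4, 5]
```

Only the response tokens (4, 5) contribute to the loss; the instruction tokens are ignored.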
7. Direct Preference Optimisation (DPO)¶
DPO replaces the reward model + PPO pipeline of RLHF with a simpler closed-form objective trained directly on preference pairs (chosen, rejected).
import torch
from trl import DPOTrainer, DPOConfig
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig
# Dataset must have columns: prompt, chosen, rejected
ds = load_dataset("json", data_files="data/preferences.jsonl", split="train")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-3.2-1B-Instruct",
torch_dtype=torch.bfloat16,
device_map="auto",
)
lora_config = LoraConfig(r=16, lora_alpha=32, task_type="CAUSAL_LM")
dpo_config = DPOConfig(
output_dir="checkpoints/llama-dpo",
beta=0.1, # KL penalty coefficient — higher = stay closer to reference
num_train_epochs=1,
per_device_train_batch_size=2,
gradient_accumulation_steps=8,
learning_rate=5e-5,
bf16=True,
logging_steps=10,
report_to="wandb",
)
trainer = DPOTrainer(
model=model,
ref_model=None, # None uses PEFT reference automatically
args=dpo_config,
train_dataset=ds,
processing_class=tokenizer,
peft_config=lora_config,
)
trainer.train()
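The per-pair DPO loss itself is compact enough to write out. A minimal sketch taking the summed log-probabilities as plain floats (in training these come from the policy and the frozen reference model):

```python
import math

def dpo_loss(policy_chosen: float, policy_rejected: float,
             ref_chosen: float, ref_rejected: float,
             beta: float = 0.1) -> float:
    """-log sigmoid(beta * margin) for one (chosen, rejected) pair,
    where margin is the policy-vs-reference log-ratio difference."""
    margin = (policy_chosen - ref_chosen) - (policy_rejected - ref_rejected)
    return math.log(1 + math.exp(-beta * margin))  # == -log(sigmoid(beta * margin))

# Policy favours the chosen response more than the reference does -> small loss
print(round(dpo_loss(-10.0, -30.0, -20.0, -25.0), 4))  # 0.2014
```

Raising beta amplifies the margin inside the sigmoid, so deviations from the reference model are penalised more sharply.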
8. Fast Fine-Tuning with Unsloth¶
Unsloth rewrites the attention kernels and backward pass in Triton, and reports roughly 2× faster training with substantially lower VRAM use (up to ~70% on some workloads) than a plain peft + trl setup, with no accuracy loss. It wraps the same SFTTrainer / DPOTrainer API, so the rest of your code stays unchanged.
pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"
pip install --no-deps trl peft accelerate bitsandbytes
8.1 LoRA / QLoRA Setup¶
from unsloth import FastLanguageModel
import torch
# Loads model in 4-bit and patches kernels automatically
model, tokenizer = FastLanguageModel.from_pretrained(
model_name="unsloth/llama-3-8b-bnb-4bit", # 4-bit pre-quantised hub models
max_seq_length=2048,
dtype=None, # auto: bfloat16 on Ampere+, float16 otherwise
load_in_4bit=True,
)
model = FastLanguageModel.get_peft_model(
model,
r=16,
target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
"gate_proj", "up_proj", "down_proj"],
lora_alpha=16,
lora_dropout=0, # Unsloth recommends 0; uses a fused kernel path
bias="none",
use_gradient_checkpointing="unsloth", # 30% less VRAM than HF's impl
random_state=42,
)
model.print_trainable_parameters()
# trainable params: 41,943,040 || all params: 8,072,884,224 || trainable%: 0.52%
Pre-quantised models on the Hub
Unsloth publishes 4-bit versions of popular models (unsloth/mistral-7b-bnb-4bit, unsloth/llama-3-8b-bnb-4bit, unsloth/gemma-7b-bnb-4bit, etc.). Loading these is faster than quantising on the fly with bitsandbytes.
8.2 SFT with Unsloth¶
from unsloth import FastLanguageModel, is_bfloat16_supported
from trl import SFTTrainer, SFTConfig
from datasets import load_dataset
model, tokenizer = FastLanguageModel.from_pretrained(
model_name="unsloth/llama-3-8b-bnb-4bit",
max_seq_length=2048,
load_in_4bit=True,
)
model = FastLanguageModel.get_peft_model(
model,
r=16,
target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
"gate_proj", "up_proj", "down_proj"],
lora_alpha=16,
lora_dropout=0,
bias="none",
use_gradient_checkpointing="unsloth",
)
# Alpaca-style prompt template
PROMPT_TEMPLATE = """\
Below is an instruction. Write a response.
### Instruction:
{}
### Response:
{}"""
def format_prompts(examples):
    texts = []
    for instruction, output in zip(examples["instruction"], examples["output"]):
        texts.append(PROMPT_TEMPLATE.format(instruction, output) + tokenizer.eos_token)
    return {"text": texts}
ds = load_dataset("yahma/alpaca-cleaned", split="train")
ds = ds.map(format_prompts, batched=True)
trainer = SFTTrainer(
model=model,
processing_class=tokenizer,
train_dataset=ds,
args=SFTConfig(
dataset_text_field="text",
max_seq_length=2048,
output_dir="checkpoints/llama-unsloth-sft",
per_device_train_batch_size=2,
gradient_accumulation_steps=4,
warmup_steps=5,
num_train_epochs=3,
learning_rate=2e-4,
fp16=not is_bfloat16_supported(),
bf16=is_bfloat16_supported(),
optim="adamw_8bit",
weight_decay=0.01,
lr_scheduler_type="linear",
logging_steps=25,
report_to="wandb",
),
)
trainer.train()
# Save adapter (a few MB)
model.save_pretrained("adapters/llama-unsloth")
tokenizer.save_pretrained("adapters/llama-unsloth")
# Optional: merge and export
model.save_pretrained_merged(
"models/llama-unsloth-merged",
tokenizer,
save_method="merged_16bit", # or "merged_4bit", "lora"
)
# Export to GGUF for llama.cpp / Ollama
model.save_pretrained_gguf("models/llama-unsloth-gguf", tokenizer, quantization_method="q4_k_m")
save_pretrained_merged vs merge_and_unload
Unsloth's save_pretrained_merged is the equivalent of PEFT's merge_and_unload().save_pretrained(...) but also supports direct GGUF export for on-device inference.
8.3 DPO with Unsloth¶
from unsloth import FastLanguageModel, is_bfloat16_supported
from unsloth import PatchDPOTrainer # patches trl's DPOTrainer for Unsloth kernels
from trl import DPOTrainer, DPOConfig
from datasets import load_dataset
PatchDPOTrainer() # call before instantiating DPOTrainer
model, tokenizer = FastLanguageModel.from_pretrained(
model_name="unsloth/llama-3-8b-instruct-bnb-4bit",
max_seq_length=2048,
load_in_4bit=True,
)
model = FastLanguageModel.get_peft_model(
model,
r=16,
target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
"gate_proj", "up_proj", "down_proj"],
lora_alpha=16,
lora_dropout=0,
bias="none",
use_gradient_checkpointing="unsloth",
)
# Dataset must have columns: prompt, chosen, rejected
ds = load_dataset("json", data_files="data/preferences.jsonl", split="train")
trainer = DPOTrainer(
model=model,
ref_model=None, # None → PEFT reference (no extra VRAM)
processing_class=tokenizer,
train_dataset=ds,
args=DPOConfig(
output_dir="checkpoints/llama-unsloth-dpo",
beta=0.1,
num_train_epochs=1,
per_device_train_batch_size=2,
gradient_accumulation_steps=4,
learning_rate=5e-5,
bf16=is_bfloat16_supported(),
fp16=not is_bfloat16_supported(),
logging_steps=10,
report_to="wandb",
),
)
trainer.train()
trainer.save_model("checkpoints/llama-unsloth-dpo/final")
8.4 Unsloth vs Plain PEFT+TRL — Quick Comparison¶
| | Plain peft + trl | Unsloth |
|---|---|---|
| Training speed | Baseline | ~1.5–2× faster |
| VRAM for 7B (4-bit) | ~16 GB | ~9–11 GB |
| API compatibility | — | Same SFTTrainer / DPOTrainer |
| GGUF export | Manual (llama.cpp) | Built-in |
| Windows support | Yes | Linux / WSL2 / Colab only |
| Model support | All HF models | LLaMA, Mistral, Gemma, Phi, Qwen, Falcon, … |
9. Hyperparameter Guide¶
| Hyperparameter | Typical Range | Notes |
|---|---|---|
| Learning rate | 1e-5 – 3e-4 | Lower for full fine-tune, higher for LoRA |
| LoRA rank r | 8 – 128 | Higher = more capacity, more VRAM |
| lora_alpha | r – 2r | alpha/r is the effective learning rate scale |
| Batch size | 8 – 128 (effective) | Use gradient accumulation to reach effective size |
| Epochs | 1 – 5 | More data → fewer epochs; overfitting risk with 3+ on small datasets |
| Warmup | 3% – 10% of steps | Use ~6% as default |
| DPO beta | 0.01 – 0.5 | Higher = more conservative; start with 0.1 |
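The batch-size and warmup rows reduce to simple arithmetic. A sketch with hypothetical numbers (10,000 examples, 3 epochs, and the per-device batch size and accumulation steps from the SFT example above):

```python
def effective_batch_size(per_device: int, accum_steps: int, n_gpus: int = 1) -> int:
    """Examples consumed per optimizer step."""
    return per_device * accum_steps * n_gpus

def warmup_steps(n_examples: int, epochs: int, eff_batch: int, ratio: float = 0.06) -> int:
    """Warmup steps as a fraction of total optimizer steps."""
    total_steps = (n_examples // eff_batch) * epochs
    return int(total_steps * ratio)

eff = effective_batch_size(per_device=2, accum_steps=16)
print(eff, warmup_steps(n_examples=10_000, epochs=3, eff_batch=eff))  # 32 56
```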
10. Fine-Tuning Checklist¶
- Base model selected and tested for baseline performance
- Dataset formatted with correct chat template
- Completion-only masking enabled (for instruction tuning)
- LoRA rank and targets chosen based on available VRAM
- Gradient checkpointing enabled (training_args.gradient_checkpointing=True)
- Eval loss monitored every epoch; early stopping configured
- Final adapter saved and can be loaded back successfully
- Merged model tested end-to-end at inference