The main goal of this notebook is to adapt Victor Sanh's implementation of movement pruning to:
- Integrate with a custom trainer
- Experiment with pruning on small datasets
- Be compatible with v4 of the `transformers` library
%load_ext autoreload
%autoreload 2
import torch
import datasets
import transformers
datasets.logging.set_verbosity_error()
transformers.logging.set_verbosity_error()
from datasets import load_dataset
from transformers import (AutoTokenizer, AutoModelForQuestionAnswering, default_data_collator, AdamW,
get_linear_schedule_with_warmup)
from transformerlab.question_answering import *
from transformerlab.pruning import *
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using transformers v{transformers.__version__} and datasets v{datasets.__version__}")
print(f"Running on device: {device}")
As usual, we'll be using the SQuAD v1.1 dataset as our benchmark, so let's quickly load it as follows:
squad_ds = load_dataset("squad")
squad_ds
Next, let's tokenize and encode a subset so we can run the experiments more quickly:
model_ckpt = 'bert-base-uncased'
tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
num_train_examples = 1600
num_eval_examples = 320
train_ds, eval_ds, eval_examples = convert_examples_to_features(squad_ds, tokenizer, num_train_examples, num_eval_examples)
Create the trainer
There are three main components we need in order to fine-prune with a `Trainer`:
- A cubic sparsity scheduler to control the amount of pruning at each training step (see the schedule below)
- Optimisation of the scores ${\bf S}$ after $T$ gradient updates
- A loss that accounts for the current mask threshold
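Concretely, the cubic schedule (following the movement pruning paper) keeps the threshold at its initial value $v_i$ for the first $t_i$ warmup steps, clamps it at its final value $v_f$ for the last $t_f$ steps, and in between decays it as

$$v^{(t)} = v_f + (v_i - v_f)\left(1 - \frac{t - t_i}{T - t_i - t_f}\right)^3,$$

where $T$ is the total number of training steps. The regularisation strength is scaled along with it as $\lambda^{(t)} = \lambda_f \, v^{(t)} / v_f$. This is what `_schedule_threshold` computes in the code below.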
The following code does the trick:
class PruningTrainingArguments(QuestionAnsweringTrainingArguments):
def __init__(self, *args, initial_threshold=1., final_threshold=0.1, initial_warmup=1, final_warmup=2, final_lambda=0.,
mask_scores_learning_rate=0., **kwargs):
super().__init__(*args, **kwargs)
self.initial_threshold = initial_threshold
self.final_threshold = final_threshold
self.initial_warmup = initial_warmup
self.final_warmup = final_warmup
self.final_lambda = final_lambda
self.mask_scores_learning_rate = mask_scores_learning_rate
class PruningTrainer(QuestionAnsweringTrainer):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if self.args.max_steps > 0:
self.t_total = self.args.max_steps
self.args.num_train_epochs = self.args.max_steps // (len(self.get_train_dataloader()) // self.args.gradient_accumulation_steps) + 1
else:
self.t_total = len(self.get_train_dataloader()) // self.args.gradient_accumulation_steps * self.args.num_train_epochs
def create_optimizer_and_scheduler(self, num_training_steps: int):
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
{
"params": [p for n, p in self.model.named_parameters() if "mask_score" in n and p.requires_grad],
"lr": self.args.mask_scores_learning_rate,
},
{
"params": [
p
for n, p in self.model.named_parameters()
if "mask_score" not in n and p.requires_grad and not any(nd in n for nd in no_decay)
],
"lr": self.args.learning_rate,
"weight_decay": self.args.weight_decay,
},
{
"params": [
p
for n, p in self.model.named_parameters()
if "mask_score" not in n and p.requires_grad and any(nd in n for nd in no_decay)
],
"lr": self.args.learning_rate,
"weight_decay": 0.0,
},
]
self.optimizer = AdamW(optimizer_grouped_parameters, lr=self.args.learning_rate, eps=self.args.adam_epsilon)
self.lr_scheduler = get_linear_schedule_with_warmup(
self.optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=self.t_total
)
def compute_loss(self, model, inputs):
threshold, regu_lambda = self._schedule_threshold(
step=self.state.global_step+1,
total_step=self.t_total,
warmup_steps=self.args.warmup_steps,
final_threshold=self.args.final_threshold,
initial_threshold=self.args.initial_threshold,
final_warmup=self.args.final_warmup,
initial_warmup=self.args.initial_warmup,
final_lambda=self.args.final_lambda,
)
inputs["threshold"] = threshold
outputs = model(**inputs)
loss, start_logits_stu, end_logits_stu = outputs
return loss
def _schedule_threshold(
self,
step: int,
total_step: int,
warmup_steps: int,
initial_threshold: float,
final_threshold: float,
initial_warmup: int,
final_warmup: int,
final_lambda: float,
):
if step <= initial_warmup * warmup_steps:
threshold = initial_threshold
elif step > (total_step - final_warmup * warmup_steps):
threshold = final_threshold
else:
spars_warmup_steps = initial_warmup * warmup_steps
spars_schedu_steps = (final_warmup + initial_warmup) * warmup_steps
mul_coeff = 1 - (step - spars_warmup_steps) / (total_step - spars_schedu_steps)
threshold = final_threshold + (initial_threshold - final_threshold) * (mul_coeff ** 3)
regu_lambda = final_lambda * threshold / final_threshold
return threshold, regu_lambda
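To get a feel for how the threshold evolves during training, here's a small self-contained sanity check that mirrors the logic of `_schedule_threshold` above (the step counts are made up purely for illustration):
def cubic_threshold(step, total_step, warmup_steps, initial_threshold=1.0,
                    final_threshold=0.1, initial_warmup=1, final_warmup=2):
    # same piecewise schedule as `_schedule_threshold`, without the lambda scaling
    if step <= initial_warmup * warmup_steps:
        return initial_threshold
    if step > total_step - final_warmup * warmup_steps:
        return final_threshold
    spars_warmup_steps = initial_warmup * warmup_steps
    spars_schedu_steps = (final_warmup + initial_warmup) * warmup_steps
    mul_coeff = 1 - (step - spars_warmup_steps) / (total_step - spars_schedu_steps)
    return final_threshold + (initial_threshold - final_threshold) * (mul_coeff ** 3)

# flat at 1.0 during warmup, cubic decay in the middle, then flat at 0.1 for the final warmup steps
print([round(cubic_threshold(s, total_step=1000, warmup_steps=100), 3) for s in range(0, 1001, 100)])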
The next thing to do is configure the trainer. First, we need to use the special "masked" model and its configuration from the `transformerlab.pruning` module:
masked_config = MaskedBertConfig(pruning_method='topK', mask_init='constant', mask_scale=0.)
def model_init():
return MaskedBertForQuestionAnswering.from_pretrained(model_ckpt, config=masked_config).to(device)
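Under the hood, a masked BERT model replaces its linear layers with variants that multiply the weights by a binary mask derived from learned scores ${\bf S}$: with `pruning_method='topK'`, the fraction `threshold` of weights with the highest scores is kept and the rest are zeroed. Here's a rough, illustrative sketch of that idea (names like `MaskedLinearSketch` are made up; this is not the `transformerlab` implementation, which also uses a straight-through estimator so the scores still receive gradients through the hard mask):
import torch
import torch.nn as nn
import torch.nn.functional as F

class MaskedLinearSketch(nn.Linear):
    def __init__(self, in_features, out_features):
        super().__init__(in_features, out_features)
        # one learned score per weight; random init here just to make the demo mask
        # non-trivial (the config above uses a constant init)
        self.mask_scores = nn.Parameter(0.01 * torch.randn_like(self.weight))

    def forward(self, x, threshold=1.0):
        # keep the top `threshold` fraction of weights by score, zero out the rest
        k = max(1, int(threshold * self.mask_scores.numel()))
        cutoff = self.mask_scores.flatten().kthvalue(self.mask_scores.numel() - k + 1).values
        mask = (self.mask_scores >= cutoff).float()
        return F.linear(x, self.weight * mask, self.bias)

layer = MaskedLinearSketch(8, 4)
out = layer(torch.randn(2, 8), threshold=0.5)  # roughly half of the 32 weights are active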
Here we're using a `model_init` function so that we can perform multiple runs with the same trainer. Next, we specify the hyperparameters that will be fixed across each run:
batch_size = 16
logging_steps = len(train_ds) // batch_size
# pruning params
initial_threshold = 1.
initial_warmup = 1
final_warmup = 2
final_lambda = 0
pruning_training_args = PruningTrainingArguments(
output_dir="checkpoints",
evaluation_strategy = "epoch",
learning_rate=3e-5,
per_device_train_batch_size=batch_size,
per_device_eval_batch_size=batch_size,
logging_steps=logging_steps,
initial_threshold=initial_threshold,
initial_warmup=initial_warmup,
final_warmup=final_warmup,
final_lambda=final_lambda)
pruning_trainer = PruningTrainer(
model_init=model_init,
args=pruning_training_args,
train_dataset=train_ds,
eval_dataset=eval_ds,
eval_examples=eval_examples,
tokenizer=tokenizer
)
Next, let's wrap the key hyperparameters in a function. Note that with our current implementation we need to add the `final_threshold` column to the evaluation set:
def fine_prune(final_threshold, num_train_epochs, mask_scores_learning_rate=1e-2):
eval_ds.reset_format()
pruning_trainer.eval_dataset = eval_ds.map(lambda x : {'threshold': final_threshold})
pruning_trainer.args.final_threshold = final_threshold
pruning_trainer.args.mask_scores_learning_rate = mask_scores_learning_rate
pruning_trainer.args.num_train_epochs = num_train_epochs
pruning_trainer.args.warmup_steps = int(num_train_examples / batch_size * num_train_epochs * .1)
print(f"Fine-pruning {(1-pruning_trainer.args.final_threshold)*100:.2f}% of weights with lr = {pruning_trainer.args.learning_rate} and mask_lr = {pruning_trainer.args.mask_scores_learning_rate} and {pruning_trainer.args.warmup_steps} warmup steps")
pruning_trainer.train()
fine_prune(1., 3)
fine_prune(0.9, 10)
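The first call is effectively a dense fine-tuning baseline: with `final_threshold=1.0` the threshold never drops below 1, so no weights are masked. The second call fine-prunes 10% of the weights (keeping the top 90% by score) over ten epochs.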