A partial re-implementation of Movement Pruning: Adaptive Sparsity by Fine-Tuning by Victor Sanh, Thomas Wolf, and Alexander M. Rush [arXiv:2005.07683]

The main goal of this notebook is to adapt Victor Sanh's implementation of movement pruning to:

  • Integrate with a custom trainer
  • Experiment with pruning on small datasets
  • Be compatible with v4 of the transformers library

Load libraries

%load_ext autoreload
%autoreload 2
import torch

import datasets
import transformers
datasets.logging.set_verbosity_error()
transformers.logging.set_verbosity_error()

from datasets import load_dataset
from transformers import (AutoTokenizer, AutoModelForQuestionAnswering, default_data_collator, AdamW, 
                          get_linear_schedule_with_warmup)

from transformerlab.question_answering import *
from transformerlab.pruning import *

# Run on the GPU whenever PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using transformers v{transformers.__version__} and datasets v{datasets.__version__}")
print(f"Running on device: {device}")
Using transformers v4.1.1 and datasets v1.2.0
Running on device: cuda

Load data

As usual, we'll use the SQuAD v1.1 dataset as our benchmark, so let's load it:

# Load the SQuAD dataset from the Hugging Face hub; the notebook displays its splits below.
squad_ds = load_dataset(path="squad")
squad_ds
DatasetDict({
    train: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 87599
    })
    validation: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 10570
    })
})

Next, let's tokenize and encode a subset so we can run the experiments more quickly:

model_ckpt = 'bert-base-uncased'
tokenizer = AutoTokenizer.from_pretrained(model_ckpt)

# Only a small slice of SQuAD is used so that each experiment runs quickly.
num_train_examples, num_eval_examples = 1600, 320
# Presumably returns (train features, eval features, raw eval examples) — the
# raw examples are later handed to the trainer for answer post-processing.
train_ds, eval_ds, eval_examples = convert_examples_to_features(
    squad_ds, tokenizer, num_train_examples, num_eval_examples
)

Create the trainer

There are three main components we need in order to fine-prune with a Trainer:

  • A cubic sparsity scheduler to control the amount of pruning at each training step
  • Optimisation of the importance scores $\mathbf{S}$, which accumulate each weight's movement over the $T$ gradient updates and get their own learning rate
  • A loss that accounts for the current mask threshold

The following code does the trick:

class PruningTrainingArguments(QuestionAnsweringTrainingArguments):
    """Training arguments extended with the movement-pruning hyperparameters.

    Every positional/keyword argument not listed below is forwarded unchanged to
    ``QuestionAnsweringTrainingArguments``. The extra keyword-only arguments are
    stored as attributes with the same names:

        initial_threshold: mask threshold used before the cubic decay starts.
        final_threshold: mask threshold reached at the end of the schedule.
        initial_warmup: multiple of ``warmup_steps`` kept at ``initial_threshold``.
        final_warmup: multiple of ``warmup_steps`` held at ``final_threshold``.
        final_lambda: scales the regularisation coefficient returned by the schedule.
        mask_scores_learning_rate: learning rate of the mask-score parameters.
    """

    def __init__(self, *args, initial_threshold=1., final_threshold=0.1, initial_warmup=1, final_warmup=2, final_lambda=0.,
                 mask_scores_learning_rate=0., **kwargs):
        super().__init__(*args, **kwargs)

        pruning_hparams = {
            "initial_threshold": initial_threshold,
            "final_threshold": final_threshold,
            "initial_warmup": initial_warmup,
            "final_warmup": final_warmup,
            "final_lambda": final_lambda,
            "mask_scores_learning_rate": mask_scores_learning_rate,
        }
        for attr_name, attr_value in pruning_hparams.items():
            setattr(self, attr_name, attr_value)

class PruningTrainer(QuestionAnsweringTrainer):
    """QuestionAnsweringTrainer that fine-prunes a masked model with movement pruning.

    On top of the base trainer it adds:

    * a dedicated optimizer parameter group for the mask scores, trained with
      ``args.mask_scores_learning_rate``;
    * a cubic threshold schedule (``_schedule_threshold``) evaluated every step;
    * a ``compute_loss`` that passes the current threshold to the model.

    ``self.t_total`` is the total number of optimisation steps shared by the
    threshold schedule and the linear LR schedule. It is estimated at
    construction time and refreshed from ``num_training_steps`` whenever the
    optimizer and scheduler are (re)built. That way, changes to
    ``args.num_train_epochs`` / ``args.max_steps`` made after construction are
    taken into account.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Initial estimate only: create_optimizer_and_scheduler overwrites it with
        # the step count the base Trainer actually computes for each run.
        num_update_steps_per_epoch = max(
            len(self.get_train_dataloader()) // self.args.gradient_accumulation_steps, 1
        )
        if self.args.max_steps > 0:
            self.t_total = self.args.max_steps
            self.args.num_train_epochs = self.args.max_steps // num_update_steps_per_epoch + 1
        else:
            self.t_total = num_update_steps_per_epoch * self.args.num_train_epochs

    def create_optimizer_and_scheduler(self, num_training_steps: int):
        """Build AdamW (mask scores / decayed weights / undecayed weights) and a linear LR schedule.

        Args:
            num_training_steps: total optimisation steps for the current run, as
                passed by the base Trainer. It is stored in ``self.t_total`` so the
                threshold schedule in ``compute_loss`` stays in sync with the LR
                schedule, even if the training arguments changed after construction.
        """
        # The value computed in __init__ goes stale when args.num_train_epochs is
        # modified between runs, so prefer the step count for this run.
        if num_training_steps:
            self.t_total = num_training_steps

        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                # Mask scores get their own learning rate.
                "params": [p for n, p in self.model.named_parameters() if "mask_score" in n and p.requires_grad],
                "lr": self.args.mask_scores_learning_rate,
            },
            {
                "params": [
                    p
                    for n, p in self.model.named_parameters()
                    if "mask_score" not in n and p.requires_grad and not any(nd in n for nd in no_decay)
                ],
                "lr": self.args.learning_rate,
                "weight_decay": self.args.weight_decay,
            },
            {
                # Biases and LayerNorm weights are excluded from weight decay.
                "params": [
                    p
                    for n, p in self.model.named_parameters()
                    if "mask_score" not in n and p.requires_grad and any(nd in n for nd in no_decay)
                ],
                "lr": self.args.learning_rate,
                "weight_decay": 0.0,
            },
        ]

        self.optimizer = AdamW(optimizer_grouped_parameters, lr=self.args.learning_rate, eps=self.args.adam_epsilon)
        self.lr_scheduler = get_linear_schedule_with_warmup(
            self.optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=self.t_total
        )

    def compute_loss(self, model, inputs, return_outputs=False):
        """Run a forward pass at the currently scheduled mask threshold.

        Args:
            model: the masked model being fine-pruned; it must accept a
                ``threshold`` keyword argument.
            inputs: batch dict from the data collator; ``threshold`` is added to
                it in place.
            return_outputs: if True, return ``(loss, outputs)``. This matches the
                signature newer ``transformers.Trainer`` versions call with.

        Returns:
            The loss (first model output), or ``(loss, outputs)`` when
            ``return_outputs`` is True.
        """
        threshold, regu_lambda = self._schedule_threshold(
            step=self.state.global_step + 1,
            total_step=self.t_total,
            warmup_steps=self.args.warmup_steps,
            final_threshold=self.args.final_threshold,
            initial_threshold=self.args.initial_threshold,
            final_warmup=self.args.final_warmup,
            initial_warmup=self.args.initial_warmup,
            final_lambda=self.args.final_lambda,
        )
        # NOTE(review): regu_lambda is computed but never added to the loss, so
        # final_lambda currently has no effect on training.
        inputs["threshold"] = threshold
        outputs = model(**inputs)
        # Index rather than unpack: iterating a ModelOutput yields its keys,
        # while integer indexing works for both tuples and ModelOutputs.
        loss = outputs[0]
        return (loss, outputs) if return_outputs else loss

    def _schedule_threshold(
        self,
        step: int,
        total_step: int,
        warmup_steps: int,
        initial_threshold: float,
        final_threshold: float,
        initial_warmup: int,
        final_warmup: int,
        final_lambda: float,
    ):
        """Cubic threshold schedule.

        The threshold stays at ``initial_threshold`` for the first
        ``initial_warmup * warmup_steps`` steps. It is held at
        ``final_threshold`` for the last ``final_warmup * warmup_steps`` steps.
        In between it decays as
        ``final + (initial - final) * (1 - progress) ** 3``.

        Returns:
            ``(threshold, regu_lambda)`` with
            ``regu_lambda = final_lambda * threshold / final_threshold``
            (raises ZeroDivisionError if ``final_threshold`` is 0).
        """
        if step <= initial_warmup * warmup_steps:
            threshold = initial_threshold
        elif step > (total_step - final_warmup * warmup_steps):
            threshold = final_threshold
        else:
            spars_warmup_steps = initial_warmup * warmup_steps
            spars_schedu_steps = (final_warmup + initial_warmup) * warmup_steps
            mul_coeff = 1 - (step - spars_warmup_steps) / (total_step - spars_schedu_steps)
            threshold = final_threshold + (initial_threshold - final_threshold) * (mul_coeff ** 3)
        regu_lambda = final_lambda * threshold / final_threshold
        return threshold, regu_lambda

Configure the trainer

The next thing to do is configure the trainer. First, we need to use the special "masked" model and its configuration from the transformerlab.pruning module:

masked_config = MaskedBertConfig(pruning_method='topK', mask_init='constant', mask_scale=0.)


def model_init():
    """Return a freshly loaded MaskedBertForQuestionAnswering placed on ``device``.

    The weights come from the ``model_ckpt`` checkpoint, using the shared
    ``masked_config``. The function is passed to the trainer as ``model_init``
    so that repeated runs can start from a fresh model.
    """
    masked_model = MaskedBertForQuestionAnswering.from_pretrained(model_ckpt, config=masked_config)
    return masked_model.to(device)

Here we're using a model_init function so that we can perform multiple runs with the same trainer. Next, we specify the hyperparameters that will be fixed across all runs:

batch_size = 16
# Roughly one logging event per epoch: the number of full batches in train_ds.
logging_steps = len(train_ds) // batch_size

# Movement-pruning schedule settings shared by every run
initial_threshold = 1.
initial_warmup = 1
final_warmup = 2
final_lambda = 0

pruning_training_args = PruningTrainingArguments(
    output_dir="checkpoints",
    evaluation_strategy="epoch",
    learning_rate=3e-5,
    # Same batch size for training and evaluation.
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    logging_steps=logging_steps,
    # Pruning-specific extras
    initial_threshold=initial_threshold,
    initial_warmup=initial_warmup,
    final_warmup=final_warmup,
    final_lambda=final_lambda,
)
pruning_trainer = PruningTrainer(
    args=pruning_training_args,
    model_init=model_init,
    tokenizer=tokenizer,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    eval_examples=eval_examples,
)

Next, let's wrap the key hyperparameters in a function. Note that, with our current implementation, we also need to add final_threshold to the evaluation set (as a threshold column):

def fine_prune(final_threshold, num_train_epochs, mask_scores_learning_rate=1e-2, warmup_fraction=.1):
    """Reconfigure the global ``pruning_trainer`` for one fine-pruning run and train it.

    Args:
        final_threshold: mask threshold at the end of the schedule. The progress
            message reports ``1 - final_threshold`` as the percentage of weights
            pruned. It is also attached to every evaluation feature as a
            ``threshold`` column, so evaluation runs at the final threshold.
        num_train_epochs: number of training epochs.
        mask_scores_learning_rate: learning rate of the mask-score parameter group.
        warmup_fraction: fraction of the total optimisation steps used as
            ``warmup_steps`` (default 10%).
    """
    eval_ds.reset_format()
    pruning_trainer.eval_dataset = eval_ds.map(lambda x: {'threshold': final_threshold})

    args = pruning_trainer.args
    args.final_threshold = final_threshold
    args.mask_scores_learning_rate = mask_scores_learning_rate
    args.num_train_epochs = num_train_epochs

    # Derive the schedule length from the real number of optimisation steps.
    # train_ds may contain more features than num_train_examples, and gradient
    # accumulation / max_steps also change the count. Refreshing t_total keeps
    # the trainer's threshold schedule from using the epoch count it was built with.
    steps_per_epoch = max(len(pruning_trainer.get_train_dataloader()) // args.gradient_accumulation_steps, 1)
    if args.max_steps > 0:
        pruning_trainer.t_total = args.max_steps
    else:
        pruning_trainer.t_total = steps_per_epoch * num_train_epochs
    args.warmup_steps = int(pruning_trainer.t_total * warmup_fraction)

    print(f"Fine-pruning {(1-pruning_trainer.args.final_threshold)*100:.2f}% of weights with lr = {pruning_trainer.args.learning_rate} and mask_lr = {pruning_trainer.args.mask_scores_learning_rate} and {pruning_trainer.args.warmup_steps} warmup steps")
    pruning_trainer.train()

BERT-base experiments

0% pruning

As a baseline, let's set the final threshold to 1 (no pruning) and train for 3 epochs:

fine_prune(final_threshold=1.0, num_train_epochs=3)
[303/303 06:53, Epoch 3/3]
Epoch Training Loss Validation Loss Exact Match F1
1.000000 4.200213 No log 31.562500 44.250428
2.000000 2.207515 No log 41.250000 53.709763
3.000000 1.592967 No log 44.062500 56.170856



10% pruning

fine_prune(final_threshold=0.9, num_train_epochs=10)

30% pruning

50% pruning

70% pruning

90% pruning

DistilBERT experiments

Do Victor's "masked" model classes play nice with DistilBERT?