A partial reimplementation of DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter by Victor Sanh, Lysandre Debut, Julien Chaumond, and Thomas Wolf [arXiv:1910.01108]

The goal of this notebook is to explore task-specific knowledge distillation, where a teacher is used to augment the cross-entropy loss of the student during fine-tuning:

$${\cal L}(\mathbf{x}|T) = - \sum_i \bar{y}_i\log y_i(\mathbf{x}|T) -T^2 \sum_i \hat{y}_i(\mathbf{x}|T)\log y_i(\mathbf{x}|T) \,.$$

Here $T$ is the temperature, $\hat{y}_i$ are the teacher's output probabilities (softened with temperature $T$), $\bar{y}_i$ are the ground-truth labels, and $y_i$ is the student's softmax output with temperature $T$.
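To make this concrete, here is a minimal PyTorch sketch of how the two terms might be combined for a toy batch of logits. The tensor names and shapes are illustrative only, not taken from the trainer implemented later; also note that the trainer below uses a KL divergence for the second term, which differs from the cross-entropy against the softened teacher distribution only by the teacher's entropy, a constant with respect to the student:

import torch
import torch.nn.functional as F

T = 2.0  # temperature
alpha_ce, alpha_distil = 0.5, 0.5

# toy batch: 4 examples, 10 classes (shapes are illustrative only)
logits_student = torch.randn(4, 10)
logits_teacher = torch.randn(4, 10)
labels = torch.randint(0, 10, (4,))

# hard-label cross-entropy term
loss_ce = F.cross_entropy(logits_student, labels)

# soft-label distillation term: KL between temperature-softened distributions,
# scaled by T^2 to keep its gradient magnitude comparable as T varies
loss_kd = F.kl_div(
    F.log_softmax(logits_student / T, dim=-1),
    F.softmax(logits_teacher / T, dim=-1),
    reduction="batchmean") * T ** 2

loss = alpha_ce * loss_ce + alpha_distil * loss_kd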

This neat idea comes from the DistilBERT paper, where the authors found that including a "second step of distillation" produced a student that performed better than simply fine-tuning the distilled language model:

We also studied whether we could add another step of distillation during the adaptation phase by fine-tuning DistilBERT on SQuAD using a BERT model previously fine-tuned on SQuAD as a teacher for an additional term in the loss (knowledge distillation). In this setting, there are thus two successive steps of distillation, one during the pre-training phase and one during the adaptation phase. In this case, we were able to reach interesting performances given the size of the model: 79.8 F1 and 70.4 EM, i.e. within 3 points of the full model.

We'll take the same approach here and aim to reproduce the SQuAD v1 results from the paper. The results are summarised in the table below, where each entry refers to the Exact Match / F1-score on the validation set:

Implementation   BERT-base     DistilBERT    (DistilBERT)^2
HuggingFace      81.2 / 88.5   77.7 / 85.8   79.1 / 86.9
Ours             80.1 / 87.8   76.7 / 85.2   78.4 / 86.5

Load libraries

%load_ext autoreload
%autoreload 2
import math
from pprint import pprint

import torch
import torch.nn as nn
import torch.nn.functional as F

import pandas as pd
import datasets
import transformers
datasets.logging.set_verbosity_error()
transformers.logging.set_verbosity_error()

from datasets import load_dataset, load_metric
from transformers import (AutoTokenizer, AutoModelForQuestionAnswering, 
                          default_data_collator, QuestionAnsweringPipeline)

from transformerlab.question_answering import *

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using transformers v{transformers.__version__} and datasets v{datasets.__version__}")
print(f"Running on device: {device}")
Using transformers v4.1.1 and datasets v1.2.0
Running on device: cuda

Load and inspect data

squad_ds = load_dataset("squad")
squad_ds
DatasetDict({
    train: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 87599
    })
    validation: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 10570
    })
})

The key information in each example is contained in the context, question and answers fields:

pprint(squad_ds['train'][0])
{'answers': {'answer_start': [515], 'text': ['Saint Bernadette Soubirous']},
 'context': 'Architecturally, the school has a Catholic character. Atop the '
            "Main Building's gold dome is a golden statue of the Virgin Mary. "
            'Immediately in front of the Main Building and facing it, is a '
            'copper statue of Christ with arms upraised with the legend '
            '"Venite Ad Me Omnes". Next to the Main Building is the Basilica '
            'of the Sacred Heart. Immediately behind the basilica is the '
            'Grotto, a Marian place of prayer and reflection. It is a replica '
            'of the grotto at Lourdes, France where the Virgin Mary reputedly '
            'appeared to Saint Bernadette Soubirous in 1858. At the end of the '
            'main drive (and in a direct line that connects through 3 statues '
            'and the Gold Dome), is a simple, modern stone statue of Mary.',
 'id': '5733be284776f41900661182',
 'question': 'To whom did the Virgin Mary allegedly appear in 1858 in Lourdes '
             'France?',
 'title': 'University_of_Notre_Dame'}

Note that there is only one possible answer per example in the training set (i.e. answers.answer_start has a single element), but multiple possible answers in the validation set:

pprint(squad_ds['validation'][666])
{'answers': {'answer_start': [69, 69, 73],
             'text': ['the national anthem',
                      'the national anthem',
                      'national anthem']},
 'context': 'Six-time Grammy winner and Academy Award nominee Lady Gaga '
            'performed the national anthem, while Academy Award winner Marlee '
            'Matlin provided American Sign Language (ASL) translation.',
 'id': '56bec6ac3aeaaa14008c9400',
 'question': 'What did Marlee Matlin translate?',
 'title': 'Super_Bowl_50'}

We can look at the frequencies by using a little bit of Dataset.map and pandas magic:

answers_ds = squad_ds.map(lambda x : {'num_possible_answers' : pd.Series(x['answers']['answer_start']).nunique()})
answers_ds.set_format('pandas')
answers_df = answers_ds['validation'][:]
answers_df['num_possible_answers'].value_counts()
1    6238
2    3498
3     754
4      71
5       9
Name: num_possible_answers, dtype: int64

Fine-tune BERT-base

The first thing we need to do is fine-tune BERT-base so that we can use it as a teacher during the distillation step. To do so, let's tokenize and encode the texts using our helper functions from the transformerlab.question_answering module:

teacher_model_checkpoint = "bert-base-uncased"
teacher_tokenizer = AutoTokenizer.from_pretrained(teacher_model_checkpoint)

num_train_examples = len(squad_ds['train'])
num_eval_examples = len(squad_ds['validation'])
# note: convert_examples_to_features shuffles the train / eval datasets
train_ds, eval_ds, eval_examples = convert_examples_to_features(squad_ds, teacher_tokenizer, num_train_examples, num_eval_examples)
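convert_examples_to_features comes from the companion transformerlab.question_answering module, so its internals aren't shown here. At its core, such a function typically boils down to a single tokenizer call that truncates only the context and splits long contexts into overlapping windows, then maps each answer to token-level start and end positions via the offset mapping. A rough sketch of that tokenizer call (the max_length and stride values are assumptions, not read from the module):

sample = squad_ds['train'][:2]
features = teacher_tokenizer(
    sample['question'],
    sample['context'],
    truncation='only_second',        # never truncate the question, only the context
    max_length=384,                  # assumed window size
    stride=128,                      # assumed overlap between consecutive windows
    return_overflowing_tokens=True,  # emit one feature per window
    return_offsets_mapping=True,     # needed to locate the answer span in token space
    padding='max_length')
print(features.keys())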

Each encoded training example contains the usual input_ids, attention_mask and token_type_ids associated with the BERT tokenizer, along with start_positions and end_positions to denote the span of tokens where the answer lies:

teacher_tokenizer.decode(train_ds[0]['input_ids'], skip_special_tokens=True)
'what percentage of egyptians polled support death penalty for those leaving islam? the pew forum on religion & public life ranks egypt as the fifth worst country in the world for religious freedom. the united states commission on international religious freedom, a bipartisan independent agency of the us government, has placed egypt on its watch list of countries that require close monitoring due to the nature and extent of violations of religious freedom engaged in or tolerated by the government. according to a 2010 pew global attitudes survey, 84 % of egyptians polled supported the death penalty for those who leave islam ; 77 % supported whippings and cutting off of hands for theft and robbery ; and 82 % support stoning a person who commits adultery.'
train_ds[0]['start_positions'], train_ds[0]['end_positions']
(97, 98)
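As a quick sanity check, we can recover the answer text by decoding the tokens between these two positions (inclusive); the result should match the answer given in the original example:

start, end = train_ds[0]['start_positions'], train_ds[0]['end_positions']
teacher_tokenizer.decode(train_ds[0]['input_ids'][start : end + 1])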

Note that the input_ids contain both the question and context. As a sanity check, let's see that we can recover the original text by decoding the input_ids in one of the validation examples:

teacher_tokenizer.decode(eval_ds[0]['input_ids'], skip_special_tokens=True)
'in what year did massachusetts first require children to be educated in schools? private schooling in the united states has been debated by educators, lawmakers and parents, since the beginnings of compulsory education in massachusetts in 1852. the supreme court precedent appears to favor educational choice, so long as states may set standards for educational accomplishment. some of the most relevant supreme court case law on this is as follows : runyon v. mccrary, 427 u. s. 160 ( 1976 ) ; wisconsin v. yoder, 406 u. s. 205 ( 1972 ) ; pierce v. society of sisters, 268 u. s. 510 ( 1925 ) ; meyer v. nebraska, 262 u. s. 390 ( 1923 ).'
pprint(eval_examples[0])
{'answers': {'answer_start': [158, 158, 158], 'text': ['1852', '1852', '1852']},
 'context': 'Private schooling in the United States has been debated by '
            'educators, lawmakers and parents, since the beginnings of '
            'compulsory education in Massachusetts in 1852. The Supreme Court '
            'precedent appears to favor educational choice, so long as states '
            'may set standards for educational accomplishment. Some of the '
            'most relevant Supreme Court case law on this is as follows: '
            'Runyon v. McCrary, 427 U.S. 160 (1976); Wisconsin v. Yoder, 406 '
            'U.S. 205 (1972); Pierce v. Society of Sisters, 268 U.S. 510 '
            '(1925); Meyer v. Nebraska, 262 U.S. 390 (1923).',
 'id': '572759665951b619008f8884',
 'question': 'In what year did Massachusetts first require children to be '
             'educated in schools?',
 'title': 'Private_school'}

This looks good, so let's move on to fine-tuning!

Configure and initialise the trainer

Fine-tuning Transformers for extractive question answering involves a significant amount of postprocessing to map the model's logits to spans of text for the predicted answers. Again we'll use the custom trainer from transformerlab.question_answering to do this for us. First we need to specify the training arguments:

batch_size = 16
logging_steps = len(train_ds) // batch_size

teacher_args = QuestionAnsweringTrainingArguments(
    output_dir="checkpoints",
    evaluation_strategy = "epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=2,
    weight_decay=0.01,
    logging_steps=logging_steps)

print(f"Number of training examples: {train_ds.num_rows}")
print(f"Number of validation examples: {eval_ds.num_rows}")
print(f"Number of raw validation examples: {eval_examples.num_rows}")
print(f"Logging steps: {logging_steps}")
Number of training examples: 88524
Number of validation examples: 10784
Number of raw validation examples: 10570
Logging steps: 5532
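The logging_steps value of 5532 is just len(train_ds) // batch_size, i.e. roughly one log per epoch. We can also double-check the expected number of optimisation steps (a quick sanity calculation, assuming a single GPU and no gradient accumulation):

steps_per_epoch = math.ceil(len(train_ds) / batch_size)
total_steps = steps_per_epoch * int(teacher_args.num_train_epochs)
print(f"Optimisation steps per epoch: {steps_per_epoch}")
print(f"Total optimisation steps: {total_steps}")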

Next we instantiate the trainer:

def teacher_init():
    return AutoModelForQuestionAnswering.from_pretrained(teacher_model_checkpoint)

teacher_trainer = QuestionAnsweringTrainer(
    model_init=teacher_init,
    args=teacher_args,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    eval_examples=eval_examples,
    tokenizer=teacher_tokenizer)

and fine-tune (this takes around 2h on a single 16GB GPU):

teacher_trainer.train();

Finally, we save the model so we can upload it to the HuggingFace model hub:

teacher_trainer.save_model('models/bert-base-uncased-finetuned-squad-v1')

Evaluate fine-tuned model

Now that we've fine-tuned BERT-base on SQuAD, we can easily evaluate it by downloading from the Hub and initialising a new trainer:

teacher_checkpoint = "lewtun/bert-base-uncased-finetuned-squad-v1"
teacher_tokenizer = AutoTokenizer.from_pretrained(teacher_checkpoint)
teacher_finetuned = AutoModelForQuestionAnswering.from_pretrained(teacher_checkpoint)

teacher_trainer = QuestionAnsweringTrainer(
    model=teacher_finetuned,
    args=teacher_args,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    eval_examples=eval_examples,
    tokenizer=teacher_tokenizer)

teacher_trainer.evaluate()
[674/674 02:09]

{'eval_loss': 'No log',
 'eval_exact_match': 80.07568590350047,
 'eval_f1': 87.77870284880602}

These scores are within 1% of the values quoted in the DistilBERT paper; the small difference is probably due to slightly different hyperparameter choices. Let's move on to fine-tuning DistilBERT!

Fine-tune DistilBERT

Here we follow the same steps as we did for BERT-base, beginning with tokenizing the datasets:

distilbert_checkpoint = "distilbert-base-uncased"
distilbert_tokenizer = AutoTokenizer.from_pretrained(distilbert_checkpoint)

num_train_examples = len(squad_ds['train'])
num_eval_examples = len(squad_ds['validation'])
train_ds, eval_ds, eval_examples = convert_examples_to_features(squad_ds, distilbert_tokenizer, num_train_examples, num_eval_examples)

Next, we configure and initialise the trainer:

batch_size = 16
logging_steps = len(train_ds) // batch_size

distilbert_args = QuestionAnsweringTrainingArguments(
    output_dir="checkpoints",
    evaluation_strategy = "epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=3,
    weight_decay=0.01,
    logging_steps=logging_steps,
    disable_tqdm=False
)

print(f"Number of training examples: {train_ds.num_rows}")
print(f"Number of validation examples: {eval_ds.num_rows}")
print(f"Number of raw validation examples: {eval_examples.num_rows}")
print(f"Logging steps: {logging_steps}")

def distilbert_init():
    return AutoModelForQuestionAnswering.from_pretrained(distilbert_checkpoint)

distilbert_trainer = QuestionAnsweringTrainer(
    model_init=distilbert_init,
    args=distilbert_args,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    eval_examples=eval_examples,
    tokenizer=distilbert_tokenizer
)

Fine-tuning takes about 1.5h on a single 16GB GPU:

distilbert_trainer.train();
distilbert_trainer.save_model('models/distilbert-base-uncased-finetuned-squad-v1')

Distilling DistilBERT

The main thing we need for task-specific distillation is to augment the standard cross-entropy loss with a distillation term (see the equation above). We can implement this by overriding the compute_loss method of the QuestionAnsweringTrainer, but first let's define the training arguments we'll need:

class DistillationTrainingArguments(QuestionAnsweringTrainingArguments):
    def __init__(self, *args, alpha_ce=0.5, alpha_distil=0.5, temperature=2.0, **kwargs):
        super().__init__(*args, **kwargs)
        
        self.alpha_ce = alpha_ce
        self.alpha_distil = alpha_distil
        self.temperature = temperature
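Since these arguments subclass the question answering training arguments, they accept all the usual fields plus the distillation hyperparameters. For example, to weight the distillation term more heavily and use a softer teacher distribution, one could write (values here are purely illustrative, not the ones used below):

example_args = DistillationTrainingArguments(
    output_dir='checkpoints',
    alpha_ce=0.3,
    alpha_distil=0.7,
    temperature=4.0)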

For the trainer, we'll need a few ingredients:

  • We need two models (a teacher and student), and since the model attribute is the one that is optimized, we'll just add an attribute for the teacher
  • When we pass the question and context to the student or teacher, we get a range of scores (logits) for the start and end positions. Since we want to minimize the distance between the teacher and student predictions, we'll use the KL-divergence as our distillation loss
  • Once the distillation loss is computed, we take a linear combination with the cross-entropy to obtain our final loss function

The following code does the trick:

class DistillationTrainer(QuestionAnsweringTrainer):
    def __init__(self, *args, teacher_model=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.teacher = teacher_model
        self.teacher.eval()
        # re-enable all dataset columns so that fields the student doesn't use
        # (e.g. token_type_ids for the BERT teacher) still reach compute_loss
        self.train_dataset.set_format(
            type=self.train_dataset.format["type"], columns=list(self.train_dataset.features.keys()))

    def compute_loss(self, model, inputs):
        # student forward pass: DistilBERT takes no token_type_ids, so we pass only
        # the fields it expects; the returned loss is the standard cross-entropy
        inputs_stu = {
            "input_ids": inputs['input_ids'],
            "attention_mask": inputs['attention_mask'],
            "start_positions": inputs['start_positions'],
            "end_positions": inputs['end_positions'],
            }
        outputs_stu = model(**inputs_stu)
        loss = outputs_stu.loss
        start_logits_stu = outputs_stu.start_logits
        end_logits_stu = outputs_stu.end_logits

        # teacher forward pass (no gradients needed)
        with torch.no_grad():
            outputs_tea = self.teacher(
                input_ids=inputs["input_ids"],
                token_type_ids=inputs["token_type_ids"],
                attention_mask=inputs["attention_mask"])
            start_logits_tea = outputs_tea.start_logits
            end_logits_tea = outputs_tea.end_logits
        assert start_logits_tea.size() == start_logits_stu.size()
        assert end_logits_tea.size() == end_logits_stu.size()

        # distillation term: KL divergence between the temperature-softened
        # start / end distributions, scaled by T^2
        loss_fct = nn.KLDivLoss(reduction="batchmean")
        loss_start = (loss_fct(
            F.log_softmax(start_logits_stu / self.args.temperature, dim=-1),
            F.softmax(start_logits_tea / self.args.temperature, dim=-1)) * (self.args.temperature ** 2))
        loss_end = (loss_fct(
            F.log_softmax(end_logits_stu / self.args.temperature, dim=-1),
            F.softmax(end_logits_tea / self.args.temperature, dim=-1)) * (self.args.temperature ** 2))
        loss_logits = (loss_start + loss_end) / 2.0
        # linear combination of the distillation loss and the cross-entropy loss
        loss = self.args.alpha_distil * loss_logits + self.args.alpha_ce * loss
        return loss
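A quick way to sanity-check the distillation term is to feed identical "student" and "teacher" logits through the same computation; the KL divergence between a distribution and itself should be (numerically) zero:

logits = torch.randn(8, 384)  # e.g. a batch of 8 sequences of length 384 (illustrative)
kd_loss = nn.KLDivLoss(reduction='batchmean')(
    F.log_softmax(logits / 2.0, dim=-1),
    F.softmax(logits / 2.0, dim=-1)) * (2.0 ** 2)
print(kd_loss)  # ~0, up to floating-point error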

It's then a similar process to configure and initialise the trainer:

batch_size = 16
logging_steps = len(train_ds) // batch_size

student_training_args = DistillationTrainingArguments(
    output_dir=f"checkpoints",
    evaluation_strategy = "epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=3,
    weight_decay=0.01,
    logging_steps=logging_steps,
)

print(f"Number of training examples: {train_ds.num_rows}")
print(f"Number of validation examples: {eval_ds.num_rows}")
print(f"Number of raw validation examples: {eval_examples.num_rows}")
print(f"Logging steps: {logging_steps}")
Number of training examples: 88524
Number of validation examples: 10784
Number of raw validation examples: 10570
Logging steps: 5532
teacher_checkpoint = "lewtun/bert-base-uncased-finetuned-squad-v1"
student_checkpoint = "distilbert-base-uncased"
teacher_model = AutoModelForQuestionAnswering.from_pretrained(teacher_checkpoint).to(device)
student_model = AutoModelForQuestionAnswering.from_pretrained(student_checkpoint).to(device)
student_tokenizer = AutoTokenizer.from_pretrained(student_checkpoint)

distil_trainer = DistillationTrainer(
    model=student_model,
    teacher_model=teacher_model,
    args=student_training_args,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    eval_examples=eval_examples,
    tokenizer=student_tokenizer)
distil_trainer.train()
[16599/16599 2:27:26, Epoch 3/3]
Epoch   Training Loss   Validation Loss   Exact Match   F1
1       1.559650        No log            76.026490     84.654177
2       0.783048        No log            77.861873     85.693001
3       0.619299        No log            78.344371     86.211550



TrainOutput(global_step=16599, training_loss=0.9873095414488996)
distil_trainer.save_model('models/distilbert-base-uncased-distilled-squad-v1')

Speed test

As a simple benchmark, here we compare the time it takes for our teacher and student to generate 1,000 predictions on a CPU (to simulate a production environment). First, we load our fine-tuned models:

student_model_ckpt = 'lewtun/distilbert-base-uncased-distilled-squad-v1'
teacher_model_ckpt = 'lewtun/bert-base-uncased-finetuned-squad-v1'

student_tokenizer = AutoTokenizer.from_pretrained(student_model_ckpt)
student_model = AutoModelForQuestionAnswering.from_pretrained(student_model_ckpt).to('cpu')

teacher_tokenizer = AutoTokenizer.from_pretrained(teacher_model_ckpt)
teacher_model = AutoModelForQuestionAnswering.from_pretrained(teacher_model_ckpt).to('cpu')

Next we create two pipelines for the student and teacher:

student_pipe = QuestionAnsweringPipeline(student_model, student_tokenizer)
teacher_pipe = QuestionAnsweringPipeline(teacher_model, teacher_tokenizer)
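Before timing anything, it's worth sanity-checking a single prediction; the pipeline returns a dict containing the predicted answer, its character span in the context, and a confidence score (the exact score will vary between fine-tuning runs):

example = squad_ds['validation'][0]
student_pipe(question=example['question'], context=example['context'])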

And then run the inference test:

%%time

for idx in range(1000):
    context = squad_ds['validation'][idx]['context']
    question = squad_ds['validation'][idx]['question']
    teacher_pipe(question=question, context=context)
CPU times: user 39min 55s, sys: 34.8 s, total: 40min 29s
Wall time: 6min 7s
%%time

for idx in range(1000):
    context = squad_ds['validation'][idx]['context']
    question = squad_ds['validation'][idx]['question']
    student_pipe(question=question, context=context)
CPU times: user 19min 47s, sys: 13.3 s, total: 20min 1s
Wall time: 3min 1s

From this simple benchmark, we see roughly a 2x speedup from using the distilled model, with less than a 3% relative drop in Exact Match / F1-score!
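The speedup largely tracks model size: BERT-base has roughly 110M parameters versus roughly 66M for DistilBERT. A quick sketch to confirm the counts with the models already loaded above:

def count_parameters(model):
    # total number of parameters in the model
    return sum(p.numel() for p in model.parameters())

print(f"Teacher parameters: {count_parameters(teacher_model) / 1e6:.0f}M")
print(f"Student parameters: {count_parameters(student_model) / 1e6:.0f}M")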