The goal of this notebook is to explore task-specific knowledge distillation, where a teacher is used to augment the cross-entropy loss of the student during fine-tuning:
$${\cal L}(\mathbf{x}|T) = - \sum_i \bar{y}_i\log y_i(\mathbf{x}|T) -T^2 \sum_i \hat{y}_i(\mathbf{x}|T)\log y_i(\mathbf{x}|T) \,.$$
Here $T$ is the temperature, $\hat{y}_i(\mathbf{x}|T)$ are the teacher's output probabilities, $\bar{y}_i$ the ground-truth labels, and $y_i(\mathbf{x}|T)$ the student's probabilities from a softmax with temperature $T$. The first term is the standard cross-entropy against the hard labels, while the second (distillation) term nudges the student's distribution towards the teacher's; the factor of $T^2$ keeps the gradient magnitudes of the two terms comparable as the temperature changes.
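To make the formula concrete, here is a minimal PyTorch sketch that computes the two terms for a batch of made-up logits (the tensor shapes, values and the `alpha` weighting between the terms are purely illustrative; we'll use the same ingredients in the distillation trainer later on):

import torch
import torch.nn.functional as F

temperature, alpha_ce, alpha_distil = 2.0, 0.5, 0.5

# toy logits for a batch of 4 examples over 10 classes (illustrative only)
student_logits = torch.randn(4, 10)
teacher_logits = torch.randn(4, 10)
labels = torch.randint(0, 10, (4,))

# hard-label cross-entropy term
loss_ce = F.cross_entropy(student_logits, labels)

# distillation term: KL divergence between temperature-scaled distributions, rescaled by T^2
loss_kd = F.kl_div(
    F.log_softmax(student_logits / temperature, dim=-1),
    F.softmax(teacher_logits / temperature, dim=-1),
    reduction="batchmean") * temperature ** 2

loss = alpha_ce * loss_ce + alpha_distil * loss_kd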
This neat idea comes from the DistilBERT paper, where the authors found that including a "second step of distillation" produced a student that performed better than simply fine-tuning the distilled language model:
> We also studied whether we could add another step of distillation during the adaptation phase by fine-tuning DistilBERT on SQuAD using a BERT model previously fine-tuned on SQuAD as a teacher for an additional term in the loss (knowledge distillation). In this setting, there are thus two successive steps of distillation, one during the pre-training phase and one during the adaptation phase. In this case, we were able to reach interesting performances given the size of the model: 79.8 F1 and 70.4 EM, i.e. within 3 points of the full model.

We'll take the same approach here and aim to reproduce the SQuAD v1 results from the paper. The results are summarised in the table below, where each entry refers to the Exact Match / F1-score on the validation set:
| Implementation | BERT-base | DistilBERT | (DistilBERT)^2 |
|---|---|---|---|
| HuggingFace | 81.2 / 88.5 | 77.7 / 85.8 | 79.1 / 86.9 |
| Ours | 80.1 / 87.8 | 76.7 / 85.2 | 78.4 / 86.5 |
%load_ext autoreload
%autoreload 2
import math
from pprint import pprint
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import datasets
import transformers
datasets.logging.set_verbosity_error()
transformers.logging.set_verbosity_error()
from datasets import load_dataset, load_metric
from transformers import (AutoTokenizer, AutoModelForQuestionAnswering,
default_data_collator, QuestionAnsweringPipeline)
from transformerlab.question_answering import *
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using transformers v{transformers.__version__} and datasets v{datasets.__version__}")
print(f"Running on device: {device}")
squad_ds = load_dataset("squad")
squad_ds
The key information in each example is contained in the `context`, `question` and `answers` fields:
pprint(squad_ds['train'][0])
Note that there is only one possible answer per example in the training set (i.e. `answers.answer_start` has one element), but there are multiple possible answers in the validation set:
pprint(squad_ds['validation'][666])
We can look at the distribution of the number of possible answers per validation example by using a little bit of `Dataset.map` and `pandas` magic:
answers_ds = squad_ds.map(lambda x : {'num_possible_answers' : pd.Series(x['answers']['answer_start']).nunique()})
answers_ds.set_format('pandas')
answers_df = answers_ds['validation'][:]
answers_df['num_possible_answers'].value_counts()
The first thing we need to do is fine-tune BERT-base so that we can use it as a teacher during the distillation step. To do so, let's tokenize and encode the texts using our helper functions from the `transformerlab.question_answering` module:
teacher_model_checkpoint = "bert-base-uncased"
teacher_tokenizer = AutoTokenizer.from_pretrained(teacher_model_checkpoint)
num_train_examples = len(squad_ds['train'])
num_eval_examples = len(squad_ds['validation'])
# note: convert_examples_to_features shuffles the train / eval datasets
train_ds, eval_ds, eval_examples = convert_examples_to_features(squad_ds, teacher_tokenizer, num_train_examples, num_eval_examples)
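We won't dig into the internals of `convert_examples_to_features` here, but the core step is presumably the standard SQuAD preprocessing from the Hugging Face question-answering examples: tokenize each (question, context) pair with a sliding window over long contexts and keep the offset mappings so the answer characters can be located among the tokens. A rough sketch of that step (the hyperparameters are typical defaults, not necessarily the ones used inside the helper):

tokenized = teacher_tokenizer(
    squad_ds["train"][:2]["question"],
    squad_ds["train"][:2]["context"],
    truncation="only_second",        # never truncate the question, only the context
    max_length=384,
    stride=128,                      # overlap between windows of a long context
    return_overflowing_tokens=True,  # one feature per window
    return_offsets_mapping=True,     # needed to locate the answer span in the tokens
    padding="max_length")
list(tokenized.keys())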
Each encoded training example contains the usual `input_ids`, `attention_mask` and `token_type_ids` associated with the BERT tokenizer, along with `start_positions` and `end_positions` that denote the span of tokens where the answer lies:
teacher_tokenizer.decode(train_ds[0]['input_ids'], skip_special_tokens=True)
train_ds[0]['start_positions'], train_ds[0]['end_positions']
Note that the `input_ids` contain both the question and the context. As a sanity check, let's see that we can recover the original text by decoding the `input_ids` of one of the validation examples:
teacher_tokenizer.decode(eval_ds[0]['input_ids'], skip_special_tokens=True)
pprint(eval_examples[0])
This looks good, so let's move on to fine-tuning!
Fine-tuning Transformers for extractive question answering involves a significant amount of postprocessing to map the model's logits to spans of text for the predicted answers. Again we'll use the custom trainer from `transformerlab.question_answering` to do this for us. First we need to specify the training arguments:
batch_size = 16
logging_steps = len(train_ds) // batch_size
teacher_args = QuestionAnsweringTrainingArguments(
output_dir="checkpoints",
evaluation_strategy = "epoch",
learning_rate=2e-5,
per_device_train_batch_size=batch_size,
per_device_eval_batch_size=batch_size,
num_train_epochs=2,
weight_decay=0.01,
logging_steps=logging_steps)
print(f"Number of training examples: {train_ds.num_rows}")
print(f"Number of validation examples: {eval_ds.num_rows}")
print(f"Number of raw validation examples: {eval_examples.num_rows}")
print(f"Logging steps: {logging_steps}")
Next we instantiate the trainer:
def teacher_init():
    return AutoModelForQuestionAnswering.from_pretrained(teacher_model_checkpoint)
teacher_trainer = QuestionAnsweringTrainer(
model_init=teacher_init,
args=teacher_args,
train_dataset=train_ds,
eval_dataset=eval_ds,
eval_examples=eval_examples,
tokenizer=teacher_tokenizer)
and fine-tune (this takes around 2h on a single 16GB GPU):
teacher_trainer.train();
Finally, we save the model so we can upload it to the HuggingFace model hub:
teacher_trainer.save_model('models/bert-base-uncased-finetuned-squad-v1')
Now that we've fine-tuned BERT-base on SQuAD, we can easily evaluate it by downloading from the Hub and initialising a new trainer:
teacher_checkpoint = "lewtun/bert-base-uncased-finetuned-squad-v1"
teacher_tokenizer = AutoTokenizer.from_pretrained(teacher_checkpoint)
teacher_finetuned = AutoModelForQuestionAnswering.from_pretrained(teacher_checkpoint)
teacher_trainer = QuestionAnsweringTrainer(
model=teacher_finetuned,
args=teacher_args,
train_dataset=train_ds,
eval_dataset=eval_ds,
eval_examples=eval_examples,
tokenizer=teacher_tokenizer)
teacher_trainer.evaluate()
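If you'd like to sanity-check these numbers outside the trainer, the `squad` metric from `datasets` computes the same Exact Match / F1 scores. It expects one prediction and one reference per example `id` (the values below are purely illustrative):

squad_metric = load_metric("squad")
predictions = [{"id": "dummy-0", "prediction_text": "some answer"}]
references = [{"id": "dummy-0",
               "answers": {"text": ["some answer"], "answer_start": [0]}}]
squad_metric.compute(predictions=predictions, references=references)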
Our scores are within about 1% of the values quoted in the DistilBERT paper, with the small difference probably due to slightly different hyperparameter choices. Let's move on to fine-tuning DistilBERT!
Here we follow the same steps as we did for BERT-base, beginning with tokenizing the datasets:
distilbert_checkpoint = "distilbert-base-uncased"
distilbert_tokenizer = AutoTokenizer.from_pretrained(distilbert_checkpoint)
num_train_examples = len(squad_ds['train'])
num_eval_examples = len(squad_ds['validation'])
train_ds, eval_ds, eval_examples = convert_examples_to_features(squad_ds, distilbert_tokenizer, num_train_examples, num_eval_examples)
Next, we configure and initialise the trainer:
batch_size = 16
logging_steps = len(train_ds) // batch_size
distilbert_args = QuestionAnsweringTrainingArguments(
output_dir="checkpoints",
evaluation_strategy = "epoch",
learning_rate=2e-5,
per_device_train_batch_size=batch_size,
per_device_eval_batch_size=batch_size,
num_train_epochs=3,
weight_decay=0.01,
logging_steps=logging_steps,
disable_tqdm=False
)
print(f"Number of training examples: {train_ds.num_rows}")
print(f"Number of validation examples: {eval_ds.num_rows}")
print(f"Number of raw validation examples: {eval_examples.num_rows}")
print(f"Logging steps: {logging_steps}")
def distilbert_init():
    return AutoModelForQuestionAnswering.from_pretrained(distilbert_checkpoint)
distilbert_trainer = QuestionAnsweringTrainer(
model_init=distilbert_init,
args=distilbert_args,
train_dataset=train_ds,
eval_dataset=eval_ds,
eval_examples=eval_examples,
tokenizer=distilbert_tokenizer
)
Fine-tuning takes about 1.5h on a single 16GB GPU:
distilbert_trainer.train();
distilbert_trainer.save_model('models/distilbert-base-uncased-finetuned-squad-v1')
The main thing we need to do to implement task-specific distillation is to augment the standard cross-entropy loss with a distillation term (see the equation above). We can do this by overriding the `compute_loss` method of the `QuestionAnsweringTrainer`, but first let's define the training arguments we'll need:
class DistillationTrainingArguments(QuestionAnsweringTrainingArguments):
    def __init__(self, *args, alpha_ce=0.5, alpha_distil=0.5, temperature=2.0, **kwargs):
        super().__init__(*args, **kwargs)
        self.alpha_ce = alpha_ce
        self.alpha_distil = alpha_distil
        self.temperature = temperature
For the trainer, we'll need a few ingredients:

- We need two models (a teacher and a student), and since the `model` attribute is the one that is optimized, we'll just add an attribute for the teacher
- When we pass the question and context to the student or teacher, we get a range of scores (logits) for the start and end positions. Since we want to minimize the distance between the teacher and student predictions, we'll use the KL divergence as our distillation loss
- Once the distillation loss is computed, we take a linear combination with the cross-entropy to obtain our final loss function
The following code does the trick:
class DistillationTrainer(QuestionAnsweringTrainer):
    def __init__(self, *args, teacher_model=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.teacher = teacher_model
        self.teacher.eval()
        # the Trainer drops columns the student doesn't use (e.g. token_type_ids),
        # so restore them all since the BERT teacher expects them
        self.train_dataset.set_format(
            type=self.train_dataset.format["type"], columns=list(self.train_dataset.features.keys()))

    def compute_loss(self, model, inputs):
        inputs_stu = {
            "input_ids": inputs["input_ids"],
            "attention_mask": inputs["attention_mask"],
            "start_positions": inputs["start_positions"],
            "end_positions": inputs["end_positions"],
        }
        # student forward pass returns the cross-entropy loss against the hard labels
        outputs_stu = model(**inputs_stu)
        loss = outputs_stu.loss
        start_logits_stu = outputs_stu.start_logits
        end_logits_stu = outputs_stu.end_logits
        # teacher forward pass (no gradients needed)
        with torch.no_grad():
            outputs_tea = self.teacher(
                input_ids=inputs["input_ids"],
                token_type_ids=inputs["token_type_ids"],
                attention_mask=inputs["attention_mask"])
            start_logits_tea = outputs_tea.start_logits
            end_logits_tea = outputs_tea.end_logits
        assert start_logits_tea.size() == start_logits_stu.size()
        assert end_logits_tea.size() == end_logits_stu.size()
        # distillation loss: KL divergence between the temperature-scaled
        # start / end distributions, rescaled by T^2
        loss_fct = nn.KLDivLoss(reduction="batchmean")
        loss_start = (loss_fct(
            F.log_softmax(start_logits_stu / self.args.temperature, dim=-1),
            F.softmax(start_logits_tea / self.args.temperature, dim=-1)) * (self.args.temperature ** 2))
        loss_end = (loss_fct(
            F.log_softmax(end_logits_stu / self.args.temperature, dim=-1),
            F.softmax(end_logits_tea / self.args.temperature, dim=-1)) * (self.args.temperature ** 2))
        loss_logits = (loss_start + loss_end) / 2.0
        # linear combination of the distillation and cross-entropy terms
        loss = self.args.alpha_distil * loss_logits + self.args.alpha_ce * loss
        return loss
It's then a similar process to configure and initialise the trainer:
batch_size = 16
logging_steps = len(train_ds) // batch_size
student_training_args = DistillationTrainingArguments(
output_dir=f"checkpoints",
evaluation_strategy = "epoch",
learning_rate=2e-5,
per_device_train_batch_size=batch_size,
per_device_eval_batch_size=batch_size,
num_train_epochs=3,
weight_decay=0.01,
logging_steps=logging_steps,
)
print(f"Number of training examples: {train_ds.num_rows}")
print(f"Number of validation examples: {eval_ds.num_rows}")
print(f"Number of raw validation examples: {eval_examples.num_rows}")
print(f"Logging steps: {logging_steps}")
teacher_checkpoint = "lewtun/bert-base-uncased-finetuned-squad-v1"
student_checkpoint = "distilbert-base-uncased"
teacher_model = AutoModelForQuestionAnswering.from_pretrained(teacher_checkpoint).to(device)
student_model = AutoModelForQuestionAnswering.from_pretrained(student_checkpoint).to(device)
student_tokenizer = AutoTokenizer.from_pretrained(student_checkpoint)
distil_trainer = DistillationTrainer(
model=student_model,
teacher_model=teacher_model,
args=student_training_args,
train_dataset=train_ds,
eval_dataset=eval_ds,
eval_examples=eval_examples,
tokenizer=student_tokenizer)
distil_trainer.train()
distil_trainer.save_model('models/distilbert-base-uncased-distilled-squad-v1')
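To reproduce the (DistilBERT)^2 row of the results table, we can reuse the trainer's built-in evaluation (the saved checkpoint could equally be loaded into a fresh `QuestionAnsweringTrainer`, as we did for the teacher):

distil_trainer.evaluate()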
As a simple benchmark, here we compare the time it takes for our teacher and student to generate 1,000 predictions on a CPU (to simulate a production environment). First, we load our fine-tuned models:
student_model_ckpt = 'lewtun/distilbert-base-uncased-distilled-squad-v1'
teacher_model_ckpt = 'lewtun/bert-base-uncased-finetuned-squad-v1'
student_tokenizer = AutoTokenizer.from_pretrained(student_model_ckpt)
student_model = AutoModelForQuestionAnswering.from_pretrained(student_model_ckpt).to('cpu')
teacher_tokenizer = AutoTokenizer.from_pretrained(teacher_model_ckpt)
teacher_model = AutoModelForQuestionAnswering.from_pretrained(teacher_model_ckpt).to('cpu')
Next we create two pipelines for the student and teacher:
student_pipe = QuestionAnsweringPipeline(student_model, student_tokenizer)
teacher_pipe = QuestionAnsweringPipeline(teacher_model, teacher_tokenizer)
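Before timing anything, it's worth sanity-checking a single prediction; the pipeline returns a dict with the answer text, its character span in the context, and a confidence score:

sample = squad_ds['validation'][0]
pprint(student_pipe(question=sample['question'], context=sample['context']))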
And then run the inference test:
%%time
for idx in range(1000):
    context = squad_ds['validation'][idx]['context']
    question = squad_ds['validation'][idx]['question']
    teacher_pipe(question=question, context=context)
%%time
for idx in range(1000):
    context = squad_ds['validation'][idx]['context']
    question = squad_ds['validation'][idx]['question']
    student_pipe(question=question, context=context)
From this simple benchmark, we see roughly a 2x speedup from using the distilled model, with less than a 3% drop in Exact Match / F1-score!