Recently, Sylvain Gugger from HuggingFace has created some nice tutorials on using transformers for text classification and named entity recognition. One trick that caught my attention was the use of a data collator in the trainer, which automatically pads the model inputs in a batch to the length of the longest example. This bypasses the need to set a global maximum sequence length, and in practice leads to faster training since we perform fewer redundant computations on the padded tokens and attention masks.
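
To make the difference concrete, here is a minimal sketch that contrasts padding to a fixed maximum length with padding to the longest example in the batch; the checkpoint and example texts are just placeholders:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('distilbert-base-cased')
texts = ["A short review.",
         "A much longer review that produces quite a few more tokens than the first one."]

# pad every example to a fixed global length
fixed = tokenizer(texts, padding='max_length', max_length=128)
# pad only to the longest example in the batch (what the data collator does per batch)
dynamic = tokenizer(texts, padding=True)

print([len(x) for x in fixed['input_ids']])    # [128, 128]
print([len(x) for x in dynamic['input_ids']])  # both equal to the length of the longest example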

I wanted to use a data collator for both training and error analysis (e.g. by inspecting the model's top losses). One problem: during training each batch is collated on the fly, so how do I pad my inputs in subsequent Dataset.map operations?

For sequence classification tasks, the solution I ended up with was to simply grab the data collator from the trainer and use it in my post-processing functions:

data_collator = trainer.data_collator

def processing_function(batch):
    # pad inputs
    batch = data_collator(batch)
    ...
    return batch
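
The processing function can then be applied with Dataset.map like any other batched operation; the dataset name and batch size below are placeholders (the end-to-end example later in the post does this for real):

ds_enc = ds_enc.map(processing_function, batched=True, batch_size=16)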

For token classification tasks, there is a dedicated DataCollatorForTokenClassification, which expects a list of dicts, where each dict represents a single example in the dataset. Since a Dataset slice returns a dict of lists, we need two more lines to wrangle the data into the expected format:

from transformers import DataCollatorForTokenClassification

data_collator = DataCollatorForTokenClassification(trainer.tokenizer)

def processing_function(batch):
    # convert dict of lists to list of dicts
    features = [dict(zip(batch, t)) for t in zip(*batch.values())]
    # pad inputs and labels
    batch = data_collator(features)
    ...
    return batch
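
To see what that one-liner does, here is a toy batch with made-up values:

batch = {'input_ids': [[0, 1, 2], [3, 4]], 'labels': [[0, 0, 1], [1, 0]]}
[dict(zip(batch, t)) for t in zip(*batch.values())]
[{'input_ids': [0, 1, 2], 'labels': [0, 0, 1]},
 {'input_ids': [3, 4], 'labels': [1, 0]}]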

For an end-to-end example, let's grab 1,000 examples from the IMDB dataset:

from datasets import load_dataset

imdb = (load_dataset('imdb', split='train')
        .train_test_split(train_size=800, test_size=200))
imdb
DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 800
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 200
    })
})

Next, let's load a pretrained model and its corresponding tokenizer:

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

num_labels = 2
model_name = 'distilbert-base-cased'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = (AutoModelForSequenceClassification
         .from_pretrained(model_name, num_labels=num_labels)
         .to(device))

Before fine-tuning the model, we need to tokenize and encode the dataset, so let's do that with a simple Dataset.map operation:

def tokenize_and_encode(batch): 
    return tokenizer(batch['text'], truncation=True)

imdb_enc = imdb.map(tokenize_and_encode, batched=True)
imdb_enc
DatasetDict({
    train: Dataset({
        features: ['attention_mask', 'input_ids', 'label', 'text'],
        num_rows: 800
    })
    test: Dataset({
        features: ['attention_mask', 'input_ids', 'label', 'text'],
        num_rows: 200
    })
})
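
As a quick sanity check, we can peek at how many tokens a few of the encoded examples contain; since reviews vary a lot in length, this is exactly the situation where per-batch padding saves work:

# token counts for the first three training examples
[len(x) for x in imdb_enc['train'][:3]['input_ids']]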

The final step is to define the metrics,

import numpy as np
from datasets import load_metric

accuracy_score = load_metric("accuracy")

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    return accuracy_score.compute(predictions=predictions, references=labels)

the arguments for the trainer,

from transformers import TrainingArguments

batch_size = 16
logging_steps = len(imdb_enc['train']) // batch_size

training_args = TrainingArguments(
    output_dir="results",
    num_train_epochs=1,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    evaluation_strategy="epoch",
    disable_tqdm=False,
    logging_steps=logging_steps)

and the trainer itself:

Important: by default, the trainer removes in place any dataset columns that the model's forward method does not accept, so in this example imdb_enc loses the str-type text column.

from transformers import Trainer

trainer = Trainer(
    model=model,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=imdb_enc['train'],
    eval_dataset=imdb_enc['test'],
    tokenizer=tokenizer)

trainer.train();
[50/50 00:32, Epoch 1/1]

Epoch    Training Loss    Validation Loss    Accuracy
    1         0.390015           0.328747    0.875000


By default, the Trainer class uses the simple default_data_collator to collate batches of dict-like objects, but by passing the tokenizer we get a DataCollatorWithPadding instead:

data_collator = trainer.data_collator
type(data_collator)
transformers.data.data_collator.DataCollatorWithPadding

To see how this collator works, let's pass a dummy batch and observe that both the input_ids and attention_mask are padded as expected:

batch = {'input_ids': [[0,1,2], [0,1,2,3,4,5]]}
data_collator(batch)
{'input_ids': tensor([[0, 1, 2, 0, 0, 0],
        [0, 1, 2, 3, 4, 5]]), 'attention_mask': tensor([[1, 1, 1, 0, 0, 0],
        [1, 1, 1, 1, 1, 1]])}

Finally, we can calculate the loss per example with the following function.1 Since DataCollatorWithPadding also renames the label column to labels, the function accesses batch["labels"] rather than batch["label"]:

def loss_per_example(batch):
    batch = data_collator(batch)
    input_ids = torch.tensor(batch["input_ids"], device=device)
    attention_mask = torch.tensor(batch["attention_mask"], device=device)
    labels = torch.tensor(batch["labels"], device=device)

    with torch.no_grad():
        output = model(input_ids, attention_mask)
        batch["predicted_label"] = torch.argmax(output.logits, axis=1)

    loss = torch.nn.functional.cross_entropy(
        output.logits, labels, reduction="none")
    batch["loss"] = loss
    
    # datasets requires list of NumPy array data types
    for k, v in batch.items():
        batch[k] = v.cpu().numpy()

    return batch


losses_ds = imdb_enc['test'].map(
    loss_per_example, batched=True, batch_size=batch_size)

It's then a simple matter to convert losses_ds to a pandas.DataFrame and sort by loss to find the examples where the model is most confused:

import pandas as pd
pd.set_option("display.max_colwidth", None)

losses_ds.set_format('pandas')
losses_df = losses_ds[:][['label', 'predicted_label', 'loss']]
# add the text column removed by the trainer
losses_df['text'] = imdb['test']['text']
losses_df.sort_values("loss", ascending=False).head()
label predicted_label loss text
147 1 0 3.477502 Was the script more fitting for a 30 minute sitcom? Yes, but they still make it work! I thought the actors did a fantastic job with an otherwise bland script, especially Jack Black and Christopher Walken. Most people on the board seem to really hate this film. I personally can't see how that could be, but Envy is just one of those film that you either love it or hate it. Much like Napoleon Dynamite and every Leslie Neilsen movie ever made. You either think it's one of the worst movies ever made or one of the funniest. Don't avoid this movie because of the reviews. Watch it and see if you're one of the ones who really like it! If you do, I guarantee it's worth your money. If you don't like it... well, now you know.
143 1 0 2.925410 I would just like to say, that no matter how low budget the film is, it needs to be shown throughout this world the point to these movies. We don't read that much anymore, instead people want to see movies. Having this series out on DVD, has made me want to read the whole series, and want more. PLEASE MAKE ALL 8 MOVIES. Please don't change any of the characters either, it ruins the effect. Because I have grown to love the actors who have played the characters. PLEASE MAKE ALL 8 MOVIES. I want to see the message, and watch the message that these books and now movies are here to portray. We don't get that enough anymore. AWESOME JOB!!!
57 0 1 2.873445 I like Brad Pitt enormously. He is an actor with brains and wit, not to mention face, pectorals and all the rest. Since I saw him in "Thelma and Louise" a thought has been bothering me, who does he remind me of? "Troy" did it for me. He is the new Brigitte Bardot. The differences are obvious of course. Male, American etc but Brigitte Bardot comes to mind nonetheless. He is so beautiful that he is at his most effective when he plays against it. "Kalifornia" "12 Monkeys" "Fight Club" "Snatch" His self deprecating humor makes him human, almost accessible. Fortunately "Troy" will soon be forgotten. Only still photographs with Pitt, semi naked in ravishing sprint positions will decorate the walls of legions of salivating fans. Strange, "Das Boot" is one of the great films of the second part of the 20th Century. What is Wolfgang Petersen doing directing this? Well, I suppose it would be very hard to say no at the chance of working with the new Brigitte Bardot.
151 1 0 2.861723 SOLDIER is not as bad as many have made it out to be. I found the film to have some of the sacarstic, cynical humour like that in Paul Verhoven's Starship Troopers. The lack of dialogue and over the top action is deliberate and adds to the comic-book atmosphere.<br /><br />One particular trivia-bit stands out for me - Todd has the names of several space-war campaigns tattoo'd onto his chest and one of these battles is TANNHAUSER GATE. For the oblivious ones out there, Tannhauser Gate is mentioned in Roy Batty's elegiac last lines in Blade Runner. To imagine that Todd could have fought alongside android troops like Roy is mind boggling to say the least. Maybe script writer David Peoples was nostalgic?<br /><br />I'll give this one 3 out of 5.
53 0 1 2.849806 Reed Diamond plays a man suffering from amnesia who's been in a mental asylum for over a decade after he was found wondering the back roads with blood on his hands. The doctors want to test out an experimental new drug that'll return his lost memories if it works. But when the drugs give him hallucinations of a demon, he chooses to escape instead. While outside he befriends a young boy whose stepfather (Greg Grunberg) mistreats his mother, won't let her near the darkroom in his basement & acts suspicious in general.<br /><br />While the general 'mystery' of the film is a tad easy to identify way before it's revealed, I found Mr. Diamond's acting to be enthralling enough to keep my attention throughout. (In the interest of full disclosure, I've been a huge fan of his since Homicide and his brief, but extremely pivotal, role in The Shield up through Journeyman & Dollhouse) Not a great film nor a good one, but serviceable enough. Although I did like it better than the previous films that I've seen from Director/writer Michael Hurst (Room 6, Pumkinhead 4, Mansquito)<br /><br />Eye Candy: one fleeting pair of boobs in a hallucination<br /><br />My Grade: C-

1. The non-padded version of this function is adapted from an implementation by Leandro von Werra.
