I wanted to run some experiments with Victor Sanh's implementation of movement pruning so that I could compare it against a custom Trainer I had implemented. Since each epoch of training on SQuAD takes around 2 hours on a single GPU, I wanted to speed up the comparison by prune-tuning on a subset of the data.

Since it's been a while since I last worked directly with PyTorch Dataset objects,¹ I'd forgotten that you can't just slice a dataset naively. For example, the following will fail:

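```python
from torch.utils.data import RandomSampler, DataLoader

train_ds = ...
sample_ds = train_ds[:10]  # folly! this returns a dict, not a Dataset
sample_sampler = RandomSampler(sample_ds)
sample_dl = DataLoader(sample_ds, sampler=sample_sampler)
next(iter(sample_dl))  # KeyError or similar :(
```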


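This happens because slicing train_ds returns an object of a different type from Dataset (for example, a dict). RandomSampler draws its indices from range(len(sample_ds)), but len() of a dict counts its keys rather than its examples, and the DataLoader then fetches each example with sample_ds[i], which raises a KeyError because the dict has no integer keys.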
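To make the failure concrete, here is a dependency-free sketch that mimics the lookups the sampler and dataloader perform, using plain Python. `ToyDictDataset` is a hypothetical stand-in for a dataset whose `__getitem__` returns a dict; it is not part of PyTorch, and the shuffle-then-index loop only approximates what RandomSampler and DataLoader do.

```python
import random


class ToyDictDataset:
    """Hypothetical stand-in for a dataset whose __getitem__ returns a dict."""

    def __init__(self, n):
        self.input_ids = [[101, i, 102] for i in range(n)]
        self.labels = [i % 2 for i in range(n)]

    def __getitem__(self, idx):
        # ints give one example; slices give a dict of lists (like train_ds[:10])
        return {"input_ids": self.input_ids[idx], "labels": self.labels[idx]}

    def __len__(self):
        return len(self.labels)


ds = ToyDictDataset(100)
sliced = ds[:10]  # a plain dict holding 10 examples per key
print(type(sliced).__name__, len(sliced))  # dict 2  (len() counts keys, not examples)

# Roughly what RandomSampler + DataLoader do: shuffle range(len(data)),
# then fetch each example with data[i]
indices = list(range(len(sliced)))
random.shuffle(indices)
try:
    sliced[indices[0]]
except KeyError as err:
    print("KeyError:", err)  # the dict only has string keys
```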

The solution I ended up with is to use PyTorch's Subset class, which wraps the original dataset together with a sequence of indices to keep:

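```python
import numpy as np
from torch.utils.data import RandomSampler, DataLoader, Subset

train_ds = ...
num_train_samples = 100
# keep the first num_train_samples examples, but as a Dataset rather than a dict
sample_ds = Subset(train_ds, np.arange(num_train_samples))
sample_sampler = RandomSampler(sample_ds)
sample_dl = DataLoader(sample_ds, sampler=sample_sampler)
next(iter(sample_dl))
```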
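The idea behind Subset is small enough to sketch in a few lines. `MiniSubset` below is a hypothetical, simplified re-implementation for illustration only (the real class also handles lists of indices); it keeps a reference to the full dataset plus a list of indices and translates every lookup through that list:

```python
class MiniSubset:
    """Hypothetical, simplified version of the idea behind torch.utils.data.Subset."""

    def __init__(self, dataset, indices):
        self.dataset = dataset
        self.indices = list(indices)

    def __getitem__(self, idx):
        # idx is a position within the subset; map it to the original index
        return self.dataset[self.indices[idx]]

    def __len__(self):
        # the subset's length is the number of indices, not len(dataset)
        return len(self.indices)


full = [f"example-{i}" for i in range(1000)]  # any indexable sequence works
sample = MiniSubset(full, range(100))
print(len(sample))   # 100
print(sample[0])     # example-0
print(sample[-1])    # example-99
```

Because `len()` and integer indexing both refer to positions inside the subset, a sampler drawing from `range(len(sample))` only ever touches the 100 selected examples, and no data is copied.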


## A simple example

To see this in action, we'll use the IMDB dataset as an example. First, let's download and unpack the dataset:

```
!wget -nc http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz -P data
!tar -xf data/aclImdb_v1.tar.gz -C data/
```


Following the transformers docs, the next thing we need is to read the samples and labels. The following code does the trick:

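```python
from pathlib import Path

DATA = Path('data/aclImdb')

def read_imdb_split(split_dir):
    split_dir = Path(split_dir)
    texts = []
    labels = []
    # each split has one folder of .txt reviews per class
    for label_dir in ["pos", "neg"]:
        for text_file in (split_dir/label_dir).iterdir():
            texts.append(text_file.read_text(encoding="utf-8"))
            labels.append(0 if label_dir == "neg" else 1)

    return texts, labels

train_texts, train_labels = read_imdb_split(DATA/'train')

# peek at first sample and label
train_texts[0], train_labels[0]
```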

```
('For a movie that gets no respect there sure are a lot of memorable quotes listed for this gem. Imagine a movie where Joe Piscopo is actually funny! Maureen Stapleton is a scene stealer. The Moroni character is an absolute scream. Watch for Alan "The Skipper" Hale jr. as a police Sgt.',
 1)
```
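One thing to keep in mind: `Path.iterdir()` yields directory entries in arbitrary order, so `train_texts[0]` is not guaranteed to be the same review on every machine. The sketch below is a hypothetical variant, `read_imdb_split_sorted`, that sorts file names (lexicographically) so the order is reproducible; it runs on a tiny fake directory with the same pos/neg layout instead of the real download.

```python
import tempfile
from pathlib import Path


def read_imdb_split_sorted(split_dir):
    """Hypothetical variant of read_imdb_split with a reproducible file order."""
    split_dir = Path(split_dir)
    texts, labels = [], []
    for label_dir in ["pos", "neg"]:
        # sorted() fixes the order; iterdir()/glob() alone make no ordering promise
        for text_file in sorted((split_dir / label_dir).glob("*.txt")):
            texts.append(text_file.read_text(encoding="utf-8"))
            labels.append(0 if label_dir == "neg" else 1)
    return texts, labels


# Tiny fake split with the same pos/neg folder layout (made-up reviews)
with tempfile.TemporaryDirectory() as tmp:
    for label_dir, reviews in {"pos": ["Great!", "Loved it"], "neg": ["Dull"]}.items():
        (Path(tmp) / label_dir).mkdir()
        for i, review in enumerate(reviews):
            (Path(tmp) / label_dir / f"{i}.txt").write_text(review, encoding="utf-8")
    texts, labels = read_imdb_split_sorted(tmp)

print(texts)   # ['Great!', 'Loved it', 'Dull']
print(labels)  # [1, 1, 0]
```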

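Next, we need to tokenize the texts, which can be done as follows:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')
# truncate reviews to the model's maximum input length and pad them all to a common length
train_encodings = tokenizer(train_texts, truncation=True, padding=True)
```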
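Tokenizing with `padding=True` pads every sequence to the length of the longest one and returns an `attention_mask` that marks real tokens with 1 and padding with 0; that's where the trailing zeros in the `train_ds[:10]` output come from. The toy function below, `pad_batch`, is a hypothetical plain-Python illustration of that padding step only, assuming a pad id of 0; it is not the tokenizer's implementation and skips truncation (which for DistilBERT cuts reviews to 512 tokens while keeping the special tokens).

```python
def pad_batch(batch_ids, pad_id=0):
    """Toy pad-to-longest: pad each id list and build a matching attention mask."""
    longest = max(len(ids) for ids in batch_ids)
    input_ids, attention_mask = [], []
    for ids in batch_ids:
        n_pad = longest - len(ids)
        input_ids.append(ids + [pad_id] * n_pad)
        # 1 for real tokens, 0 for padding the model should ignore
        attention_mask.append([1] * len(ids) + [0] * n_pad)
    return {"input_ids": input_ids, "attention_mask": attention_mask}


enc = pad_batch([[101, 2023, 102], [101, 2023, 2001, 2307, 102]])
print(enc["input_ids"])       # [[101, 2023, 102, 0, 0], [101, 2023, 2001, 2307, 102]]
print(enc["attention_mask"])  # [[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]]
```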


Finally, we can define a custom Dataset object:

```python
import torch

class IMDbDataset(torch.utils.data.Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        # idx can be an int or a slice; both work on the underlying lists
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        return len(self.labels)

train_ds = IMDbDataset(train_encodings, train_labels)
```


Each element of train_ds is a dict whose keys correspond to the inputs expected in the forward pass of a Transformer model like BERT. If we take a slice, we get back a single dict with a batched tensor for each key:

```python
train_ds[:10]
```

```
{'input_ids': tensor([[  101,  2005,  1037,  ...,     0,     0,     0],
        [  101, 13576,  5469,  ...,     0,     0,     0],
        [  101,  1037,  5024,  ...,     0,     0,     0],
        ...,
        [  101,  2023,  2001,  ...,     0,     0,     0],
        [  101,  2081,  2044,  ...,  3286,  1011,   102],
        [  101,  2005,  1037,  ...,     0,     0,     0]]),
 'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 0, 0, 0]]),
 'labels': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])}
```

A dict like this isn't a Dataset, so samplers can't draw examples from it. Instead, the solution is to wrap our Dataset with Subset as follows:

```python
import numpy as np
from torch.utils.data import Subset

num_train_examples = 100
sample_ds = Subset(train_ds, np.arange(num_train_examples))
assert len(sample_ds) == num_train_examples
```
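A caveat about taking the first N indices: read_imdb_split loads every "pos" review before any "neg" review, so in the 25,000-review training split the first 12,500 labels are all 1 and np.arange(100) selects positive reviews only (which is why the labels in the slice above are all 1). If a class-balanced subset matters, one option is to draw random indices with a fixed seed. `random_subset_indices` below is a hypothetical stdlib helper, checked here against a label list laid out like the IMDB training split:

```python
import random


def random_subset_indices(num_total, num_samples, seed=42):
    """Hypothetical helper: draw num_samples distinct indices from range(num_total).

    A fixed seed keeps the subset reproducible between runs; sorting is
    optional and just keeps the examples in their original relative order.
    """
    rng = random.Random(seed)
    return sorted(rng.sample(range(num_total), num_samples))


# Labels laid out like the IMDB training split after read_imdb_split:
# 12,500 positive reviews followed by 12,500 negative ones
labels = [1] * 12_500 + [0] * 12_500
indices = random_subset_indices(len(labels), 100)

first_100 = labels[:100]
random_100 = [labels[i] for i in indices]
print(sum(first_100))   # 100: the first 100 examples are all positive
print(sum(random_100))  # some mix of positives and negatives
```

The resulting list can be passed straight to Subset, e.g. `Subset(train_ds, random_subset_indices(len(train_ds), 100))`. Note that the sanity check against the raw text would then compare `sample_ds[-1]` with `train_texts[indices[-1]]` rather than `train_texts[99]`.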


As a sanity check, let's compare the raw text against the decoded examples in the dataset:

```python
tokenizer.decode(sample_ds[0]['input_ids'], skip_special_tokens=True)
```

```
'for a movie that gets no respect there sure are a lot of memorable quotes listed for this gem. imagine a movie where joe piscopo is actually funny! maureen stapleton is a scene stealer. the moroni character is an absolute scream. watch for alan " the skipper " hale jr. as a police sgt.'
```

This looks good. How about the last example?

```python
print(tokenizer.decode(sample_ds[-1]['input_ids'], skip_special_tokens=True), "\n")
print(train_texts[99])
```

```
beautiful film, pure cassavetes style. gena rowland gives a stunning performance of a declining actress, dealing with success, aging, loneliness... and alcoholism. she tries to escape her own subconscious ghosts, embodied by the death spectre of a young girl. acceptance of oneself, of human condition, though its overall difficulties, is the real purpose of the film. the parallel between the theatrical sequences and the film itself are puzzling : it's like if the stage became a way out for the heroin. if all american movies could only be that top - quality, dealing with human relations on an adult level, not trying to infantilize and standardize feelings... one of the best dramas ever. 10 / 10.

Beautiful film, pure Cassavetes style. Gena Rowland gives a stunning performance of a declining actress, dealing with success, aging, loneliness...and alcoholism. She tries to escape her own subconscious ghosts, embodied by the death spectre of a young girl. Acceptance of oneself, of human condition, though its overall difficulties, is the real purpose of the film. The parallel between the theatrical sequences and the film itself are puzzling: it's like if the stage became a way out for the Heroin. If all american movies could only be that top-quality, dealing with human relations on an adult level, not trying to infantilize and standardize feelings... One of the best dramas ever. 10/10.
```
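Eyeballing two examples works, but the same check can be automated. The decoded text differs from the raw review only in casing and in the spaces inserted around punctuation ("top - quality" vs "top-quality"), so a heuristic is to lower-case both strings, drop all whitespace, and test whether the decoded text is a prefix of the raw one (truncation can cut long reviews short). `roughly_matches` is a hypothetical helper, not a transformers function; it can give false negatives when the uncased tokenizer strips accents or drops characters it maps to unknown tokens.

```python
def squash(text):
    """Lower-case and drop all whitespace, so spacing differences don't matter."""
    return "".join(text.lower().split())


def roughly_matches(decoded, raw):
    """Heuristic: does an uncased tokenizer round trip look like the raw review?

    Whitespace is ignored because decoding adds spaces around punctuation;
    the prefix test tolerates reviews that were truncated.
    """
    return squash(raw).startswith(squash(decoded))


raw = 'Watch for Alan "The Skipper" Hale jr. as a police Sgt.'
decoded = 'watch for alan " the skipper " hale jr. as a police sgt.'
print(roughly_matches(decoded, raw))               # True
print(roughly_matches("a different review", raw))  # False
```

Applied to the whole subset, this could look like `all(roughly_matches(tokenizer.decode(sample_ds[i]['input_ids'], skip_special_tokens=True), train_texts[i]) for i in range(len(sample_ds)))`, assuming the subset uses the first indices as above.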


The final step is to define the sampler and dataloader, and we're done!

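```python
from torch.utils.data import RandomSampler, DataLoader

sample_sampler = RandomSampler(sample_ds)
sample_dl = DataLoader(sample_ds, sampler=sample_sampler, batch_size=3)
next(iter(sample_dl))
```

```
{'input_ids': tensor([[  101, 13576,  5469,  ...,     0,     0,     0],
 'labels': tensor([1, 1, 1])}
```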
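Under the hood, the sampler and dataloader cooperate roughly like this: the sampler yields a random permutation of positions, and the loader fetches `dataset[i]` for each position and merges each group of per-example dicts into one batched dict. `iterate_batches` below is a simplified, hypothetical plain-Python mimic, not a PyTorch function; the real `default_collate` stacks values into tensors rather than lists, and the real loader keeps the final partial batch unless `drop_last=True`.

```python
import random


def iterate_batches(dataset, batch_size, seed=0):
    """Hypothetical mimic of DataLoader(dataset, sampler=RandomSampler(dataset), batch_size=...)."""
    order = list(range(len(dataset)))      # the sampler's job: a random permutation
    random.Random(seed).shuffle(order)
    for start in range(0, len(order), batch_size):
        examples = [dataset[i] for i in order[start:start + batch_size]]
        # collate: list of dicts -> dict of lists, one entry per key
        yield {key: [ex[key] for ex in examples] for key in examples[0]}


toy = [{"input_ids": [101, i, 102], "labels": i % 2} for i in range(10)]
batches = list(iterate_batches(toy, batch_size=3))
print(len(batches))               # 4  (3 + 3 + 3 + 1 examples)
print(len(batches[0]["labels"]))  # 3
```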
1. Mostly because I've been corrupted by the Hugging Face datasets and fastai APIs.