Slicing PyTorch Datasets
I wanted to run some experiments with Victor Sanh's implementation of movement pruning so that I could compare against a custom Trainer I had implemented. Since each epoch of training on SQuAD takes around 2 hours on a single GPU, I wanted to speed up the comparison by prune-tuning on a subset of the data.
It's been a while since I've worked directly with PyTorch Dataset objects,1 so I'd forgotten that one can't naively slice a dataset. For example, the following will fail:
from torch.utils.data import RandomSampler, DataLoader
train_ds = ...
sample_ds = train_ds[:10] # folly!
sample_sampler = RandomSampler(sample_ds)
sample_dl = DataLoader(sample_ds, sampler=sample_sampler, batch_size=4)
next(iter(sample_dl)) # KeyError or similar :(
This occurs because slicing train_ds returns an object of a different type from Dataset (e.g. a dict), so the RandomSampler doesn't know how to produce appropriate samples for the DataLoader.
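To make this concrete, here's a minimal sketch (with a hypothetical ToyDataset, not from the original code) of what slicing actually hands back: the slice object is passed straight to __getitem__, and the result is a plain dict rather than a Dataset, so the sampler and loader downstream have nothing sensible to index.

import torch
from torch.utils.data import Dataset

class ToyDataset(Dataset):
    """Hypothetical dataset whose items are dicts, like the IMDb dataset below."""
    def __init__(self, n):
        self.data = list(range(n))

    def __getitem__(self, idx):
        # idx may be an int *or* a slice -- Python passes it through unchanged
        return {"x": torch.tensor(self.data[idx])}

    def __len__(self):
        return len(self.data)

toy_ds = ToyDataset(100)
print(type(toy_ds[:10]))  # <class 'dict'> -- no longer a Dataset
# len() of that dict counts keys, so RandomSampler yields index 0 and the
# DataLoader's lookup fails with a KeyError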
The solution I ended up with is to use the Subset class to create the desired subset:
import numpy as np
from torch.utils.data import RandomSampler, DataLoader, Subset

train_ds = ...
num_train_samples = 100
sample_ds = Subset(train_ds, np.arange(num_train_samples))
sample_sampler = RandomSampler(sample_ds)
sample_dl = DataLoader(sample_ds, sampler=sample_sampler, batch_size=4)
next(iter(sample_dl))
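Why does this work? Subset keeps a reference to the parent dataset together with the chosen indices and simply delegates indexing, so the wrapped object is still a genuine Dataset that samplers and loaders understand. Roughly speaking it behaves like this simplified sketch (not the exact torch source):

from torch.utils.data import Dataset

class MySubset(Dataset):
    """Simplified stand-in for torch.utils.data.Subset."""
    def __init__(self, dataset, indices):
        self.dataset = dataset
        self.indices = indices

    def __getitem__(self, idx):
        # delegate to the parent dataset, so items keep their original format
        return self.dataset[self.indices[idx]]

    def __len__(self):
        return len(self.indices)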
To see this in action, we'll use the IMDB dataset as an example. First let's download and unpack the dataset:
!wget -nc http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz -P data
!tar -xf data/aclImdb_v1.tar.gz -C data/
Following the transformers docs, the next step is to read the samples and labels. The following code does the trick:
from pathlib import Path

DATA = Path('data/aclImdb')

def read_imdb_split(split_dir):
    split_dir = Path(split_dir)
    texts = []
    labels = []
    for label_dir in ["pos", "neg"]:
        for text_file in (split_dir/label_dir).iterdir():
            texts.append(text_file.read_text())
            labels.append(0 if label_dir == "neg" else 1)
    return texts, labels
train_texts, train_labels = read_imdb_split(f'{DATA}/train')
# peek at first sample and label
train_texts[0], train_labels[0]
Next we need to tokenize the texts, which can be done as follows:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')
train_encodings = tokenizer(train_texts, truncation=True, padding=True)
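If you want to peek at what the tokenizer produced: the result is a dict-like BatchEncoding with one list per model input (for distilbert-base-uncased that should be input_ids and attention_mask), each holding one entry per review:

# inspect the tokenized output (keys may differ for other checkpoints)
print(train_encodings.keys())
print(len(train_encodings['input_ids']))  # one encoded sequence per review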
Finally we can define a custom Dataset object:
import torch

class IMDbDataset(torch.utils.data.Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        return len(self.labels)

train_ds = IMDbDataset(train_encodings, train_labels)
Each element of train_ds is a dict with keys corresponding to the inputs expected in the forward pass of a Transformer model like BERT. If we take a slice, then we get tensors for each of the keys:
train_ds[:10]
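A quick check of what that slice actually contains (the exact sequence length depends on the tokenizer's padding):

# the slice is a plain dict of stacked tensors, not a Dataset
sliced = train_ds[:10]
print(type(sliced))               # <class 'dict'>
print(sliced['input_ids'].shape)  # torch.Size([10, max_sequence_length])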
This dict type is not suitable for sampling from, so the solution is to wrap our Dataset with Subset as follows:
import numpy as np
from torch.utils.data import Subset
num_train_examples = 100
sample_ds = Subset(train_ds, np.arange(num_train_examples))
assert len(sample_ds) == num_train_examples
As a sanity check, let's compare the raw text against the decoded examples in the dataset:
tokenizer.decode(sample_ds[0]['input_ids'], skip_special_tokens=True)
This looks good; how about the last example?
print(tokenizer.decode(sample_ds[-1]['input_ids'], skip_special_tokens=True), "\n")
print(train_texts[99])
The final step is to define the sampler and dataloader, and we're done!
from torch.utils.data import RandomSampler, DataLoader

sample_sampler = RandomSampler(sample_ds)
sample_dl = DataLoader(sample_ds, sampler=sample_sampler, batch_size=4)
next(iter(sample_dl))
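One caveat: np.arange(num_train_examples) takes the first 100 examples, and since read_imdb_split reads the pos folder first, those will all be positive reviews. If you'd rather work with a random subset, a simple variant (my own tweak, not from the original post) is to shuffle the indices before wrapping:

# hypothetical variant: sample a random subset instead of the first 100 examples
rng = np.random.default_rng(seed=42)
random_indices = rng.choice(len(train_ds), size=num_train_examples, replace=False)
random_sample_ds = Subset(train_ds, random_indices)
random_dl = DataLoader(random_sample_ds, sampler=RandomSampler(random_sample_ds), batch_size=4)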
1. Mostly because I've been corrupted by the datasets and fastai APIs. ↩