Lately, I've been using the transformers Trainer together with the datasets library, and I was a bit mystified by the disappearance of some columns in the training and validation sets after fine-tuning. It wasn't until I saw Sylvain Gugger's tutorial on question answering that I realised this is by design! Indeed, as noted in the docs[1] for the train_dataset and eval_dataset arguments of the Trainer:

If it is a datasets.Dataset, columns not accepted by the model.forward() method are automatically removed.

A simple one-liner to restore the missing columns is the following:

dataset.set_format(type=dataset.format["type"], columns=list(dataset.features.keys()))

To understand why this works, we can peek inside the relevant Trainer code

??Trainer._remove_unused_columns
Signature:
Trainer._remove_unused_columns(
    self,
    dataset:'datasets.Dataset',
    description:Union[str, NoneType]=None,
)
Docstring: <no docstring>
Source:
    def _remove_unused_columns(self, dataset: "datasets.Dataset", description: Optional[str] = None):
        if not self.args.remove_unused_columns:
            return
        # Inspect model forward signature to keep only the arguments it accepts.
        signature = inspect.signature(self.model.forward)
        signature_columns = list(signature.parameters.keys())
        # Labels may be named label or label_ids, the default data collator handles that.
        signature_columns += ["label", "label_ids"]
        columns = [k for k in signature_columns if k in dataset.column_names]
        ignored_columns = list(set(dataset.column_names) - set(signature_columns))
        dset_description = "" if description is None else f"in the {description} set "
        logger.info(
            f"The following columns {dset_description}don't have a corresponding argument in `{self.model.__class__.__name__}.forward` and have been ignored: {', '.join(ignored_columns)}."
        )
        dataset.set_format(type=dataset.format["type"], columns=columns)
File:      /usr/local/lib/python3.6/dist-packages/transformers/trainer.py
Type:      function

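and see that the method finishes by calling dataset.set_format() with only the columns that match an argument of model.forward() (plus label and label_ids). Since set_format() modifies the dataset in place, the dataset we passed to the Trainer is affected too. Our one-liner simply undoes this final operation by resetting the format to include every column in Dataset.features, while keeping the existing format type.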
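To make the selection logic concrete, here is a minimal standard-library sketch that mirrors the column-selection lines of _remove_unused_columns shown above. The forward function is a hypothetical stand-in whose parameter names are loosely modeled on a DistilBERT sequence-classification head, and split_columns is an illustrative helper, not part of transformers. Unlike the Trainer, it sorts the ignored columns so the output is deterministic.

```python
import inspect
from typing import Callable, List, Tuple


def forward(input_ids=None, attention_mask=None, head_mask=None, inputs_embeds=None,
            labels=None, output_attentions=None, output_hidden_states=None, return_dict=None):
    """Hypothetical stand-in for model.forward(); only the parameter names matter here."""
    return None


def split_columns(forward_fn: Callable, column_names: List[str]) -> Tuple[List[str], List[str]]:
    """Sketch of the column-selection step in Trainer._remove_unused_columns."""
    signature_columns = list(inspect.signature(forward_fn).parameters.keys())
    # As in the Trainer source, also keep the label names the default data collator handles.
    signature_columns += ["label", "label_ids"]
    # Kept columns follow the order of the forward() signature.
    kept = [k for k in signature_columns if k in column_names]
    # The Trainer uses an unordered set difference; sorting here only makes the result stable.
    ignored = sorted(set(column_names) - set(signature_columns))
    return kept, ignored


kept, ignored = split_columns(forward, ["attention_mask", "idx", "input_ids", "label", "sentence"])
print(kept)     # ['input_ids', 'attention_mask', 'label']
print(ignored)  # ['idx', 'sentence']
```

With the column names of the tokenized CoLA dataset used in the example, idx and sentence land in the ignored list, which matches the columns that go missing after training.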

A simple example

To see this in action, let's grab the first 1,000 examples from the training split of the CoLA dataset:

from datasets import load_dataset

cola = load_dataset('glue', 'cola', split='train[:1000]')
cola
Dataset({
    features: ['sentence', 'label', 'idx'],
    num_rows: 1000
})

Here we can see that the dataset has three columns (listed in Dataset.features): sentence, label, and idx. By inspecting the Dataset.format attribute

cola.format
{'type': None,
 'format_kwargs': {},
 'columns': ['idx', 'label', 'sentence'],
 'output_all_columns': False}

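we also see that the format type is None, which means examples are returned as plain Python objects, and that all three columns are included. Next, let's load a pretrained model and its corresponding tokenizer: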

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

num_labels = 2
model_name = 'distilbert-base-uncased'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = (AutoModelForSequenceClassification
         .from_pretrained(model_name, num_labels=num_labels)
         .to(device))

Before fine-tuning the model, we need to tokenize and encode the dataset, so let's do that with a simple Dataset.map operation:

def tokenize_and_encode(batch):
    return tokenizer(batch['sentence'], truncation=True)

cola_enc = cola.map(tokenize_and_encode, batched=True)
cola_enc
Dataset({
    features: ['attention_mask', 'idx', 'input_ids', 'label', 'sentence'],
    num_rows: 1000
})

Note that the encoding process has added two new Dataset.features to our dataset: attention_mask and input_ids. Since we don't care about evaluation, let's create a minimal trainer and fine-tune the model for one epoch:

from transformers import TrainingArguments, Trainer

batch_size = 16
logging_steps = len(cola_enc) // batch_size

training_args = TrainingArguments(
    output_dir="results",
    num_train_epochs=1,
    per_device_train_batch_size=batch_size,
    disable_tqdm=False,
    logging_steps=logging_steps)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=cola_enc,
    tokenizer=tokenizer)

trainer.train();
[63/63 00:03, Epoch 1/1]
Step    Training Loss
62      0.630255

By inspecting one of the training examples

cola_enc[0]
{'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
 'input_ids': [101,
  2256,
  2814,
  2180,
  1005,
  1056,
  4965,
  2023,
  4106,
  1010,
  2292,
  2894,
  1996,
  2279,
  2028,
  2057,
  16599,
  1012,
  102],
 'label': 1}

it seems that we've lost our sentence and idx columns! However, by inspecting the features attribute

cola_enc.features
{'attention_mask': Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None),
 'idx': Value(dtype='int32', id=None),
 'input_ids': Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None),
 'label': ClassLabel(num_classes=2, names=['unacceptable', 'acceptable'], names_file=None, id=None),
 'sentence': Value(dtype='string', id=None)}

we see that they are still present in the dataset. Applying our one-liner to restore them gives the desired result:

cola_enc.set_format(type=cola_enc.format["type"], columns=list(cola_enc.features.keys()))
cola_enc[0]
{'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
 'idx': 0,
 'input_ids': [101,
  2256,
  2814,
  2180,
  1005,
  1056,
  4965,
  2023,
  4106,
  1010,
  2292,
  2894,
  1996,
  2279,
  2028,
  2057,
  16599,
  1012,
  102],
 'label': 1,
 'sentence': "Our friends won't buy this analysis, let alone the next one we propose."}
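The columns can come back at all because set_format() only changes which columns are returned when indexing; the underlying data is untouched. The ToyFormattedTable class below is a hypothetical, heavily simplified model of that idea in plain Python. It is not how datasets implements formatting (which is backed by Apache Arrow), only an illustration of the difference between stored features and formatted columns.

```python
from typing import Any, Dict, List, Optional


class ToyFormattedTable:
    """Hypothetical toy: stored columns are kept separate from the columns the format exposes."""

    def __init__(self, data: Dict[str, List[Any]]) -> None:
        self._data = data  # every stored column; set_format never modifies this
        self._format_columns: Optional[List[str]] = None  # None means "all columns"

    @property
    def column_names(self) -> List[str]:
        return list(self._data)

    def set_format(self, columns: Optional[List[str]] = None) -> None:
        # Only records which columns to expose; no data is deleted.
        self._format_columns = columns

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        cols = self.column_names if self._format_columns is None else self._format_columns
        return {c: self._data[c][idx] for c in cols}


table = ToyFormattedTable({"sentence": ["a", "b"], "label": [1, 0], "idx": [0, 1]})
table.set_format(columns=["label"])           # analogous to what the Trainer does
print(table[0])                               # {'label': 1}
table.set_format(columns=table.column_names)  # analogous to the one-liner
print(table[0])                               # {'sentence': 'a', 'label': 1, 'idx': 0}
```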

1. Proof positive that I only read documentation after some threshold of confusion.