# Masked classes for pruning BERT-like Transformers
The classes in this module are adapted from Victor Sanh's implementation of *Movement Pruning: Adaptive Sparsity by Fine-Tuning* in the `examples/research_projects` directory of the `transformers` repository. The main changes are as follows:
- To make these classes compatible with v4 of `transformers`, we have replaced all instances of `BertLayerNorm` with `torch.nn.LayerNorm` (see the first sketch below).
- In the `forward` method of `TopKBinarizer`, we check whether `threshold` is a float or a list, since the latter occurs in the `datasets` format (see the second sketch after this list).
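The `BertLayerNorm` change is a drop-in substitution, since `torch.nn.LayerNorm` takes the same constructor arguments. A minimal sketch of the pattern, where the class name is hypothetical and `config` follows the standard `BertConfig` attributes:

```python
import torch

class MaskedBertOutput(torch.nn.Module):  # hypothetical class name, for illustration
    def __init__(self, config):
        super().__init__()
        # BertLayerNorm no longer exists in transformers v4;
        # torch.nn.LayerNorm is a drop-in replacement.
        self.LayerNorm = torch.nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
```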
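The sketch below shows roughly what the `threshold` check might look like, based on a simplified version of the original movement pruning `TopKBinarizer`; the exact internals in this module may differ:

```python
import torch
from torch import autograd

class TopKBinarizer(autograd.Function):
    """Binarizes a score matrix by keeping the top `threshold` fraction of entries."""

    @staticmethod
    def forward(ctx, inputs: torch.Tensor, threshold):
        # `threshold` may arrive wrapped in a list (e.g. in the `datasets`
        # format), so unwrap it to a float first.
        if isinstance(threshold, list):
            threshold = threshold[0]
        # Sort the scores and keep the top `threshold` fraction as a 0/1 mask.
        mask = inputs.clone()
        _, idx = inputs.flatten().sort(descending=True)
        j = int(threshold * inputs.numel())
        flat_mask = mask.flatten()  # a view over `mask` for contiguous tensors
        flat_mask[idx[:j]] = 1.0
        flat_mask[idx[j:]] = 0.0
        return mask

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through estimator: gradients pass through unchanged.
        return grad_output, None
```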
## Masked versions of BERT
In order to compute the adaptive mask during pruning, a special set of "masked" BERT classes is required to account for sparsity in the model's weights. François Lagunas is planning to release a generic version of these classes around March 2021, so we should consider using those when they become available.
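To make the idea concrete, here is a simplified sketch of a masked linear layer in the spirit of these classes: an importance score per weight is learned during fine-tuning, and the `TopKBinarizer` sketched above converts the scores into a binary mask applied to the weights on every forward pass. The real classes carry additional configuration (e.g. the choice of pruning method), so this is illustrative only:

```python
import torch
from torch import nn
from torch.nn import functional as F

class MaskedLinear(nn.Linear):
    """Linear layer whose weights are multiplied by a learned binary mask (sketch)."""

    def __init__(self, in_features: int, out_features: int, bias: bool = True):
        super().__init__(in_features, out_features, bias=bias)
        # One importance score per weight, learned alongside the weights.
        self.mask_scores = nn.Parameter(torch.randn(self.weight.size()) * 0.01)

    def forward(self, inputs: torch.Tensor, threshold: float = 0.5):
        # Binarize the scores, then apply the mask so the effective
        # weight matrix is sparse.
        mask = TopKBinarizer.apply(self.mask_scores, threshold)
        return F.linear(inputs, self.weight * mask, self.bias)
```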