# Masked classes for pruning BERT-like Transformers
The classes in this module are adapted from Victor Sanh's implementation of [Movement Pruning: Adaptive Sparsity by Fine-Tuning](https://arxiv.org/abs/2005.07683) in the `examples/research_projects` directory of the `transformers` repository. The main changes are as follows:
- To make these classes compatible with v4 of `transformers`, we have replaced all instances of `BertLayerNorm` with `torch.nn.LayerNorm`.
- In the `forward` method of `TopKBinarizer`, we check whether `threshold` is a float or a list, since the latter occurs in the `datasets` format (see the sketch after this list).
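To make the second change concrete, here is a minimal sketch of a top-k binarizer in the spirit of the original implementation: scores are sorted, the top `threshold` fraction is kept as a {0, 1} mask, and gradients flow back through a straight-through estimator. The exact unwrapping of the list case is an assumption, not a verbatim copy of this module's code.

```python
import torch
from torch import autograd


class TopKBinarizer(autograd.Function):
    """Binarises a score tensor: the top `threshold` fraction becomes 1, the rest 0."""

    @staticmethod
    def forward(ctx, inputs: torch.Tensor, threshold):
        # `threshold` may arrive wrapped in a list when read back from a
        # `datasets`-style record; unwrapping the first element is a
        # hypothetical handling for this sketch.
        if isinstance(threshold, list):
            threshold = threshold[0]
        # Keep the `threshold` fraction of highest scores.
        mask = torch.zeros_like(inputs)
        _, idx = inputs.flatten().sort(descending=True)
        j = int(threshold * inputs.numel())
        mask.view(-1)[idx[:j]] = 1.0
        return mask

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through estimator: pass the gradient to the scores
        # unchanged; no gradient for `threshold`.
        return grad_output, None


if __name__ == "__main__":
    scores = torch.randn(4, 4, requires_grad=True)
    mask = TopKBinarizer.apply(scores, 0.5)  # keep the top 50% of scores
    mask.sum().backward()                    # gradients reach `scores` via the STE
    print(mask)
```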
## Masked versions of BERT
In order to compute the adaptive mask during pruning, a special set of "masked" BERT classes is required to account for the sparsity in the model's weights; a sketch of the core mechanism is shown below. François Lagunas is planning to release a generic version of these classes around March 2021, so we should consider switching to them once they become available.
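For context, here is a minimal sketch of what such a masked layer looks like, reusing the `TopKBinarizer` sketched above. The class name, the default `threshold`, and the zero initialisation of the scores are simplifications for illustration; the actual masked BERT classes apply this pattern to every linear projection in the attention and feed-forward blocks and thread the threshold through each submodule.

```python
import torch
from torch import nn
import torch.nn.functional as F


class MaskedLinear(nn.Linear):
    """`nn.Linear` whose weights are multiplied by a learned binary mask.

    Each weight has a trainable importance score; on every forward pass the
    scores are binarised (via the `TopKBinarizer` above) and the resulting
    {0, 1} mask is applied to the weights, so pruning decisions are learned
    jointly with fine-tuning.
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = True):
        super().__init__(in_features, out_features, bias=bias)
        # One trainable importance score per weight entry.
        self.mask_scores = nn.Parameter(torch.zeros_like(self.weight))

    def forward(self, input: torch.Tensor, threshold: float = 0.5) -> torch.Tensor:
        # Derive the current mask from the scores and apply it to the weights.
        mask = TopKBinarizer.apply(self.mask_scores, threshold)
        return F.linear(input, self.weight * mask, self.bias)
```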