# Writing a multistage pipeline: BERT Fine-tuning + Distillation¶

It is common to train a model on one task and then reuse that model, or part of it, in a fine-tuning stage on a different dataset. Flambé allows users to link directly to objects from a previous stage in the pipeline without having to run two separate experiments (more information on linking here).

For this tutorial, we look at a recent use case in natural language processing: fine-tuning a BERT model on a text classification task, then applying knowledge distillation to that model in order to obtain a smaller model with comparable performance. Knowledge distillation is interesting here because BERT is relatively slow, which can hinder its use in production systems.
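To make the idea concrete, here is a minimal sketch of a distillation objective in plain Python: the student is trained against a blend of the hard-label cross-entropy and the KL divergence to the teacher's temperature-softened distribution. The `temperature` and `alpha_kl` names mirror the `DistillationTrainer` arguments used later in this tutorial, but the exact loss Flambé computes may differ from this sketch.

```python
import math

def softmax(logits, temperature=1.0):
    """Softmax over a list of logits, softened by a temperature."""
    exps = [math.exp(l / temperature) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]

def kl_divergence(p, q):
    """KL(p || q) for two discrete distributions."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def distillation_loss(student_logits, teacher_logits, true_label,
                      temperature=2.0, alpha_kl=0.5):
    """Blend of hard-label cross-entropy and soft-label KL divergence."""
    student_soft = softmax(student_logits, temperature)
    teacher_soft = softmax(teacher_logits, temperature)
    # Cross-entropy against the gold label (no temperature).
    ce = -math.log(softmax(student_logits)[true_label])
    # KL between the softened teacher and student distributions.
    kl = kl_divergence(teacher_soft, student_soft)
    return (1 - alpha_kl) * ce + alpha_kl * kl
```

With `alpha_kl: 0` this reduces to ordinary supervised training; a temperature above 1 flattens the teacher's distribution, exposing its relative preferences among incorrect classes.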

## First step: BERT fine-tuning¶

We start by taking a pretrained BERT encoder, and we fine-tune it on the SSTDataset by adding a linear output layer on top of the encoder. We start with the dataset, and apply a special TextField object which can load the pretrained vocabulary learned by BERT.

The SSTDataset below inherits from our TabularDataset component. This object takes as input a transform dictionary in which you can specify Field objects. A Field is considered a featurizer: it can take an arbitrary number of columns and return any number of features.
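To illustrate the featurizer contract, here is a toy label field in Python. The class and method names (`setup`, `process`) are hypothetical, not necessarily Flambé's actual Field API; the point is only that a featurizer observes raw column values and turns each one into a feature, here a label index:

```python
class SimpleLabelField:
    """Illustrative featurizer: builds a vocabulary over label strings
    and maps each label to an integer index."""

    def __init__(self):
        self.vocab = {}

    def setup(self, *columns):
        # Observe the raw column values and build the label vocabulary.
        for column in columns:
            for label in column:
                if label not in self.vocab:
                    self.vocab[label] = len(self.vocab)

    @property
    def vocab_size(self):
        return len(self.vocab)

    def process(self, label):
        # Featurize one raw value into its index.
        return self.vocab[label]

labels = ["positive", "negative", "positive", "neutral"]
field = SimpleLabelField()
field.setup(labels)
indices = [field.process(l) for l in labels]
```

A `vocab_size` property like the one above is what makes links such as `dataset.label.vocab_size` in the configs below possible.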

Tip

You are free to completely override the Dataset object and not use Field, as long as you follow the Dataset interface.

In this example, we apply a PretrainedTransformerField and a LabelField.

```yaml
dataset: !SSTDataset
  transform:
    text: !PretrainedTransformerField
      alias: 'bert-base-uncased'
    label: !LabelField
```

Tip

By default, fields are aligned with the input columns, but one can also make an explicit mapping if more than one feature should be created from the same column:

```yaml
transform:
  text:
    columns: 0
    field: !PretrainedTransformerField
      alias: 'bert-base-uncased'
  label:
    columns: 1
    field: !LabelField
```


Next we define our model. We use the TextClassifier object, which takes an Embedder and an output layer. Here, we use the PretrainedTransformerEmbedder.

```yaml
teacher: !TextClassifier
  embedder: !PretrainedTransformerEmbedder
    alias: 'bert-base-uncased'
    pool: True
  output_layer: !SoftmaxLayer
    input_size: !@ teacher[embedder].hidden_size
    output_size: !@ dataset.label.vocab_size  # We link to the size of the label space
```


Finally we put all of this in a Trainer object, which will execute training.

Note

Recall that you can’t link to parent objects because they won’t be initialized yet; that’s why we link directly to the embedder via bracket notation (it will be initialized because it’s above in the config and not a parent) and access its hidden_size attribute.

Tip

Any component can be specified at the top level in the pipeline or as an argument to another Component object. A Component has a run method which, for many objects, consists of just a pass statement, meaning that using them at the top level is equivalent to declaring them. The Trainer, however, executes training through its run method, and will therefore be both declared and executed.
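The pattern described in the tip can be sketched as follows. The class names and the convention that `run` returns whether more work remains are illustrative assumptions, not Flambé's exact base class:

```python
class Component:
    """Base class: declared in the pipeline; run() may be a no-op."""

    def run(self) -> bool:
        # Returning False signals there is no work to execute, so using
        # the component at the top level only declares it.
        return False

class Model(Component):
    pass  # Inherits the no-op run: top-level use == declaration.

class Trainer(Component):
    def __init__(self, model, n_epochs=2):
        self.model = model        # kept as an instance attribute
        self.epochs_done = 0
        self.n_epochs = n_epochs

    def run(self) -> bool:
        # Does one unit of training work per call; the runtime keeps
        # calling run() until it returns False.
        self.epochs_done += 1
        return self.epochs_done < self.n_epochs
```

Keeping the trained model as an instance attribute (`self.model`) is also what allows a later stage to link to it, as the distillation step below does with `finetune.model`.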

```yaml
finetune: !Trainer
  dataset: !@ dataset
  train_sampler: !BaseSampler
    batch_size: 16
  val_sampler: !BaseSampler
    batch_size: 16
  model: !@ teacher
  loss_fn: !torch.NLLLoss
  metric_fn: !Accuracy
  params: !@ finetune[model].trainable_params
  lr: 0.00005
```


## Second step: Knowledge distillation¶

We now introduce a second model, which we will call the student model:

```yaml
student: !TextClassifier
  embedder: !Embedder
    embedding: !Embeddings
      num_embeddings: !@ dataset.text.vocab_size
      embedding_dim: 300
    encoder: !PooledRNNEncoder
      input_size: 300
      rnn_type: sru
      n_layers: 2
      hidden_size: 256
      pooling: !LastPooling
  output_layer: !SoftmaxLayer
    input_size: !@ student[embedder][encoder].hidden_size
    output_size: !@ dataset.label.vocab_size
```


Attention

Note how this new model is much less complex than the original, making it more appropriate for production systems.

In the above example, we decided to reuse the same embedding layer, which means we do not have to provide a new Field to the dataset. However, you may also decide to perform different preprocessing for the student model:

```yaml
dataset: !SSTDataset
  transform:
    teacher_text: !PretrainedTransformerField
      alias: 'bert-base-uncased'
      lower: true
    label: !LabelField
    student_text: !TextField
```


We can now proceed to the final step of our pipeline which is the DistillationTrainer. The key here is to link to the teacher model that was obtained in the finetune stage above.

Tip

You can specify to the DistillationTrainer which columns of the dataset to pass to the teacher model, and which to pass to the student model through the teacher_columns and student_columns arguments.
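As an illustration of what such column routing might look like internally (a hypothetical sketch, not Flambé's implementation), each example's columns are split between the two models:

```python
def route_columns(row, teacher_columns, student_columns):
    """Split one example's columns between teacher and student inputs."""
    teacher_input = [row[i] for i in teacher_columns]
    student_input = [row[i] for i in student_columns]
    return teacher_input, student_input

# Example row: (teacher_text, label, student_text), as in the dataset
# with separate teacher_text and student_text fields above.
row = ("[CLS] a great movie [SEP]", "positive", "a great movie")
teacher_in, student_in = route_columns(row,
                                       teacher_columns=[0],
                                       student_columns=[2])
```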

```yaml
distill: !DistillationTrainer
  dataset: !@ dataset
  train_sampler: !BaseSampler
    batch_size: 16
  val_sampler: !BaseSampler
    batch_size: 16
  teacher_model: !@ finetune.model
  student_model: !@ student
  loss_fn: !torch.NLLLoss
  metric_fn: !Accuracy
  params: !@ distill[student_model].trainable_params
  lr: 0.00005
  alpha_kl: 0.5
  temperature: 1
```


Attention

Linking to the teacher model directly would use the model pre-finetuning, so we link to the model inside the finetune stage instead. Note that for these links to work, it’s important for the Trainer object to keep the model as an instance attribute.

That’s it! You can find the full configuration below.

## Full configuration¶

```yaml
!Experiment

name: fine-tune-bert-then-distill
pipeline:

  dataset: !SSTDataset
    transform:
      text: !PretrainedTransformerField
        alias: 'bert-base-uncased'
      label: !LabelField

  teacher: !TextClassifier
    embedder: !PretrainedTransformerEmbedder
      alias: 'bert-base-uncased'
      pool: True
    output_layer: !SoftmaxLayer
      input_size: !@ teacher[embedder].hidden_size
      output_size: !@ dataset.label.vocab_size  # We link to the size of the label space

  student: !TextClassifier
    embedder: !Embedder
      embedding: !Embeddings
        num_embeddings: !@ dataset.text.vocab_size
        embedding_dim: 300
      encoder: !PooledRNNEncoder
        input_size: 300
        rnn_type: sru
        n_layers: 2
        hidden_size: 256
        pooling: last
    output_layer: !SoftmaxLayer
      input_size: !@ student[embedder][encoder].hidden_size
      output_size: !@ dataset.label.vocab_size

  finetune: !Trainer
    dataset: !@ dataset
    train_sampler: !BaseSampler
      batch_size: 16
    val_sampler: !BaseSampler
      batch_size: 16
    model: !@ teacher
    loss_fn: !torch.NLLLoss
    metric_fn: !Accuracy
    params: !@ finetune[model].trainable_params
    lr: 0.00005

  distill: !DistillationTrainer
    dataset: !@ dataset
    train_sampler: !BaseSampler
      batch_size: 16
    val_sampler: !BaseSampler
      batch_size: 16
    teacher_model: !@ finetune.model
    student_model: !@ student
    loss_fn: !torch.NLLLoss
    metric_fn: !Accuracy
    params: !@ distill[student_model].trainable_params
    lr: 0.00005
    alpha_kl: 0.5
    temperature: 1
```