Writing a multistage pipeline: BERT Fine-tuning + Distillation

It is common to want to train a model on a particular task and then reuse that model, or part of it, in a fine-tuning stage on a different dataset. Flambé allows users to link directly to objects from a previous stage in the pipeline without having to run two separate experiments (more information on linking here).
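
For example, a later stage can reference an attribute of an object built in an earlier stage through the !@ link syntax used throughout this tutorial. As a minimal sketch, previewing the distillation stage defined later in this tutorial:

distill: !DistillationTrainer
  teacher_model: !@ finetune.model  # the fine-tuned model produced by the 'finetune' stage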

For this tutorial, we look at a recent use case in natural language processing: fine-tuning a BERT model on a text classification task, and then applying knowledge distillation to that model in order to obtain a smaller model that still performs well. Knowledge distillation is interesting here because BERT is relatively slow, which can hinder its use in production systems.

First step: BERT fine-tuning

We start by taking a pretrained BERT encoder and fine-tuning it on the SSTDataset by adding a linear output layer on top of the encoder. We begin with the dataset, applying a special TextField object that can load the pretrained vocabulary learned by BERT.

The SSTDataset below inherits from our TabularDataset component. This object takes as input a transform dictionary, where you can specify Field objects. A Field is considered a featurizer: it can take an arbitrary number of columns and return any number of features.

Tip

You are free to completely override the Dataset object and not use Field, as long as you follow its interface: Dataset.

In this example, we apply a BERTTextField and a LabelField.

dataset: !SSTDataset
    transform:
        text: !BERTTextField.from_alias
            alias: 'bert-base-uncased'
            lower: true
        label: !LabelField

Tip

By default, fields are aligned with the input columns, but one can also make an explicit mapping if more than one feature should be created from the same column:

transform:
    text:
        columns: 0
        field: !BERTTextField.from_alias
            alias: 'bert-base-uncased'
            lower: true
    label:
        columns: 1
        field: !LabelField

Next we define our model. We use the TextClassifier object, which takes an Embedder and an output layer:

teacher: !TextClassifier

  embedder: !Embedder
    embedding: !BERTEmbeddings.from_alias
      path: 'bert-base-uncased'
      embedding_freeze: True
    encoder: !BERTEncoder.from_alias
      path: 'bert-base-uncased'
      pool_last: True

  output_layer: !SoftmaxLayer
    input_size: !@ teacher.embedder.encoder.config.hidden_size
    output_size: !@ dataset.label.vocab_size  # We link to the size of the label space

Finally we put all of this in a Trainer object, which will execute training.

Tip

Any component can be specified at the top level in the pipeline or as an argument
to another Component object. A Component has a run method, which for many objects consists of just a pass statement, meaning that using them at the top level is equivalent to declaring them. The Trainer, however, executes training through its run method, and will therefore be both declared and executed.

finetune: !Trainer
  dataset: !@ dataset
  train_sampler: !BaseSampler
    batch_size: 16
  val_sampler: !BaseSampler
    batch_size: 16
  model: !@ teacher
  loss_fn: !torch.NLLLoss
  metric_fn: !Accuracy
  optimizer: !AdamW
    params: !@ finetune.model.trainable_params
    lr: 0.00005

Second step: Knowledge distillation

We now introduce a second model, which we will call the student model:

student: !TextClassifier

  embedder: !Embedder
    embedding: !BERTEmbeddings.from_alias
      path: 'bert-base-uncased'
      embedding_freeze: True
    encoder: !PooledRNNEncoder
      rnn_type: sru
      n_layers: 2
      hidden_size: 256
      pooling: last
  output_layer: !SoftmaxLayer
    input_size: !@ student.embedder.encoder.hidden_size
    output_size: !@ dataset.label.vocab_size

Attention

Note how this new model is much less complex than the teacher model, making it more appropriate for production systems.

In the above example, we decided to reuse the same embedding layer, which means we don't have to provide a new Field to the dataset. However, you may also decide to perform different preprocessing for the student model:

dataset: !SSTDataset
    transform:
        teacher_text: !BERTTextField.from_alias
            alias: 'bert-base-uncased'
            lower: true
        label: !LabelField
        student_text: !TextField

We can now proceed to the final step of our pipeline which is the DistillationTrainer. The key here is to link to the teacher model that was obtained in the finetune stage above.

Tip

You can specify to the DistillationTrainer which columns of the dataset to pass to the teacher model, and which to pass to the student model through the teacher_columns and student_columns arguments.
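
For instance, with the teacher_text / student_text transform shown above, you could add something like the following to the distill stage defined below (the column indices and the exact sequence format here are assumptions; they depend on the order in which the transform produces features):

  teacher_columns: [0]  # assumed index of the 'teacher_text' feature
  student_columns: [2]  # assumed index of the 'student_text' feature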

distill: !DistillationTrainer
  dataset: !@ dataset
  train_sampler: !BaseSampler
    batch_size: 16
  val_sampler: !BaseSampler
    batch_size: 16
  teacher_model: !@ finetune.model
  student_model: !@ student
  loss_fn: !torch.NLLLoss
  metric_fn: !Accuracy
  optimizer: !torch.Adam
    params: !@ distill.student_model.trainable_params
    lr: 0.00005
  alpha_kl: 0.5
  temperature: 1

Attention

Linking to the teacher model directly would use the model before fine-tuning, so we link to the model inside the finetune stage instead. Note that for these links to work, it's important for the Trainer object to have the model as an instance attribute.
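
Concretely, the difference between the two links is:

  # links to the model as declared above, before fine-tuning:
  teacher_model: !@ teacher
  # links to the same model after the finetune stage has trained it:
  teacher_model: !@ finetune.model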

That’s it! You can find the full configuration below.

Full configuration

dataset: !SSTDataset
    transform:
        text: !BERTTextField.from_alias
            alias: 'bert-base-uncased'
            lower: true
        label: !LabelField

teacher: !TextClassifier
  embedder: !Embedder
    embedding: !BERTEmbeddings.from_alias
      path: 'bert-base-uncased'
      embedding_freeze: True
    encoder: !BERTEncoder.from_alias
      path: 'bert-base-uncased'
      pool_last: True
  output_layer: !SoftmaxLayer
    input_size: !@ teacher.embedder.encoder.config.hidden_size
    output_size: !@ dataset.label.vocab_size  # We link to the size of the label space

student: !TextClassifier
  embedder: !Embedder
    embedding: !BERTEmbeddings.from_alias
      path: 'bert-base-uncased'
      embedding_freeze: True
    encoder: !PooledRNNEncoder
      rnn_type: sru
      n_layers: 2
      hidden_size: 256
      pooling: last
  output_layer: !SoftmaxLayer
    input_size: !@ student.embedder.encoder.hidden_size
    output_size: !@ dataset.label.vocab_size

finetune: !Trainer
  dataset: !@ dataset
  train_sampler: !BaseSampler
    batch_size: 16
  val_sampler: !BaseSampler
    batch_size: 16
  model: !@ teacher
  loss_fn: !torch.NLLLoss
  metric_fn: !Accuracy
  optimizer: !AdamW
    params: !@ finetune.model.trainable_params
    lr: 0.00005

distill: !DistillationTrainer
  dataset: !@ dataset
  train_sampler: !BaseSampler
    batch_size: 16
  val_sampler: !BaseSampler
    batch_size: 16
  teacher_model: !@ finetune.model
  student_model: !@ student
  loss_fn: !torch.NLLLoss
  metric_fn: !Accuracy
  optimizer: !torch.Adam
    params: !@ distill.student_model.trainable_params
    lr: 0.00005
  alpha_kl: 0.5
  temperature: 1