Writing a multistage pipeline: BERT Fine-tuning + Distillation

It is common to train a model on a particular task and then reuse that model, or part of it, in a fine-tuning stage on a different dataset. Flambé allows users to link directly to objects from a previous stage in the pipeline without having to run two separate experiments (more information on linking here).
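As a quick preview of the linking syntax used throughout this tutorial, a link is written with the !@ tag and points to an object, or an attribute of an object, defined earlier in the pipeline. For example, the distillation stage built later in this tutorial reuses the model trained in the finetune stage:

distill: !DistillationTrainer
  # Reuse the model produced by the earlier `finetune` stage
  teacher_model: !@ finetune.model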

For this tutorial, we look at a recent use case in natural language processing: fine-tuning a BERT model on a text classification task, and then applying knowledge distillation to that model in order to obtain a smaller model that retains high performance. Knowledge distillation is interesting here because BERT is relatively slow, which can hinder its use in production systems.

First step: BERT fine-tuning

We start by taking a pretrained BERT encoder and fine-tuning it on the SSTDataset by adding a linear output layer on top of the encoder. First comes the dataset, to which we apply a special TextField object that can load the pretrained vocabulary learned by BERT.

The SSTDataset below inherits from our TabularDataset component. This object takes as input a transform dictionary, where you can specify Field objects. A Field is considered a featurizer: it can take an arbitrary number of columns and return any number of features.

Tip

You are free to completely override the Dataset object and not use Field, as long as you follow the Dataset interface.

In this example, we apply a PretrainedTransformerField and a LabelField.

dataset: !SSTDataset
    transform:
        text: !PretrainedTransformerField
            alias: 'bert-base-uncased'
        label: !LabelField

Tip

By default, fields are aligned with the input columns, but one can also make an explicit mapping if more than one feature should be created from the same column:

transform:
    text:
        columns: 0
        field: !PretrainedTransformerField
            alias: 'bert-base-uncased'
    label:
        columns: 1
        field: !LabelField

Next, we define our model. We use the TextClassifier object, which takes an Embedder and an output layer. Here, we use the PretrainedTransformerEmbedder:

teacher: !TextClassifier

  embedder: !PretrainedTransformerEmbedder
    alias: 'bert-base-uncased'
    pool: True

  output_layer: !SoftmaxLayer
    input_size: !@ teacher[embedder].hidden_size
    output_size: !@ dataset.label.vocab_size  # We link to the size of the label space

Finally, we put all of this in a Trainer object, which will execute training:

Note

Recall that you can't link to parent objects, because they won't have been initialized yet. That's why we link directly to the embedder via bracket notation (it will already be initialized, since it appears above in the config and is not a parent), and access the intended hidden_size attribute.
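To make this concrete, here is the pattern from the config above, annotated (for illustration only):

output_layer: !SoftmaxLayer
  # OK: the embedder is declared above and is not a parent of the output
  # layer, so it is already initialized when this link is resolved
  input_size: !@ teacher[embedder].hidden_size
  # A link to `teacher` itself would not work here: `teacher` is a parent
  # of this output layer and has not been initialized at this point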

Tip

Any component can be specified at the top level of the pipeline or as an argument to another Component object. A Component has a run method, which for many objects consists of just a pass statement, meaning that using them at the top level is equivalent to declaring them. The Trainer, however, executes training through its run method, and will therefore be both declared and executed.

finetune: !Trainer
  dataset: !@ dataset
  train_sampler: !BaseSampler
    batch_size: 16
  val_sampler: !BaseSampler
    batch_size: 16
  model: !@ teacher
  loss_fn: !torch.NLLLoss
  metric_fn: !Accuracy
  optimizer: !AdamW
    params: !@ finetune[model].trainable_params
    lr: 0.00005

Second step: Knowledge distillation

We now introduce a second model, which we will call the student model:

student: !TextClassifier

  embedder: !Embedder
    embedding: !Embeddings
      num_embeddings: !@ dataset.text.vocab_size
      embedding_dim: 300
    encoder: !PooledRNNEncoder
      input_size: 300
      rnn_type: sru
      n_layers: 2
      hidden_size: 256
    pooling: !LastPooling
  output_layer: !SoftmaxLayer
    input_size: !@ student[embedder][encoder].hidden_size
    output_size: !@ dataset.label.vocab_size

Attention

Note how this new model is much less complex than the original one, making it more appropriate for production systems.

In the above example, we decided to reuse the same text preprocessing for the student, which means we don't have to provide a new Field to the dataset. However, you may also decide to perform different preprocessing for the student model:

dataset: !SSTDataset
  transform:
    teacher_text: !PretrainedTransformerField
      alias: 'bert-base-uncased'
      lower: true
    label: !LabelField
    student_text: !TextField

We can now proceed to the final step of our pipeline, which is the DistillationTrainer. The key here is to link to the teacher model that was obtained in the finetune stage above.

Tip

You can specify which columns of the dataset to pass to the teacher model and which to pass to the student model through the DistillationTrainer's teacher_columns and student_columns arguments, as sketched below.
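With the alternate two-field dataset shown above, this could look roughly like the following. The column indices are illustrative assumptions based on the order of the transform dictionary (teacher_text, label, student_text); check the DistillationTrainer signature for the exact format it expects.

distill: !DistillationTrainer
  # ... other arguments as in the configuration below ...
  teacher_columns: [0, 1]  # hypothetical: teacher_text and label
  student_columns: [2, 1]  # hypothetical: student_text and label

The distill stage used in the rest of this tutorial keeps the default column handling: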

distill: !DistillationTrainer
  dataset: !@ dataset
  train_sampler: !BaseSampler
    batch_size: 16
  val_sampler: !BaseSampler
    batch_size: 16
  teacher_model: !@ finetune.model
  student_model: !@ student
  loss_fn: !torch.NLLLoss
  metric_fn: !Accuracy
  optimizer: !torch.Adam
    params: !@ distill[student_model].trainable_params
    lr: 0.00005
  alpha_kl: 0.5
  temperature: 1
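The alpha_kl and temperature arguments control the distillation objective. In the usual formulation of knowledge distillation (which we assume the DistillationTrainer follows, up to implementation details), the loss is a weighted mix of a KL-divergence term between the temperature-softened teacher and student output distributions and the regular supervised loss, roughly: loss = alpha_kl * KL(soften(teacher, T), soften(student, T)) + (1 - alpha_kl) * NLL(student, labels), where T is the temperature. With alpha_kl: 0.5 and temperature: 1 as above, both terms contribute equally and the output distributions are left unsoftened.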

Attention

Linking to the teacher object directly would use the model as it was before fine-tuning, so we instead link to the model inside the finetune stage. Note that for these links to work, the Trainer object must hold the model as an instance attribute.

That’s it! You can find the full configuration below.

Full configuration

!Experiment

name: fine-tune-bert-then-distill
pipeline:

  dataset: !SSTDataset
    transform:
      text: !PretrainedTransformerField
        alias: 'bert-base-uncased'
      label: !LabelField

  teacher: !TextClassifier
    embedder: !PretrainedTransformerEmbedder
      alias: 'bert-base-uncased'
      pool: True
    output_layer: !SoftmaxLayer
      input_size: !@ teacher[embedder].hidden_size
      output_size: !@ dataset.label.vocab_size  # We link to the size of the label space

  student: !TextClassifier
    embedder: !Embedder
      embedding: !Embeddings
        num_embeddings: !@ dataset.text.vocab_size
        embedding_dim: 300
      encoder: !PooledRNNEncoder
        input_size: 300
        rnn_type: sru
        n_layers: 2
        hidden_size: 256
      pooling: !LastPooling
    output_layer: !SoftmaxLayer
      input_size: !@ student[embedder][encoder].hidden_size
      output_size: !@ dataset.label.vocab_size

  finetune: !Trainer
    dataset: !@ dataset
    train_sampler: !BaseSampler
      batch_size: 16
    val_sampler: !BaseSampler
      batch_size: 16
    model: !@ teacher
    loss_fn: !torch.NLLLoss
    metric_fn: !Accuracy
    optimizer: !AdamW
      params: !@ finetune[model].trainable_params
      lr: 0.00005

  distill: !DistillationTrainer
    dataset: !@ dataset
    train_sampler: !BaseSampler
      batch_size: 16
    val_sampler: !BaseSampler
      batch_size: 16
    teacher_model: !@ finetune.model
    student_model: !@ student
    loss_fn: !torch.NLLLoss
    metric_fn: !Accuracy
    optimizer: !torch.Adam
      params: !@ distill[student_model].trainable_params
      lr: 0.00005
    alpha_kl: 0.5
    temperature: 1