# Tutorial 2: Finetuning Bert for Sequence Classification using a LoRA adapter

When we import a pretrained transformer model from HuggingFace, we receive the encoder/decoder weights, which aren't very useful on their own. To perform a task such as sequence classification, we add a classification head on top of the model and train its weights on the target dataset. In this tutorial, we'll fine-tune a Bert model for sequence classification with two approaches. First, we'll perform full Supervised Fine-Tuning (SFT). Then, we'll use the Mase stack to add a [LoRA](https://arxiv.org/abs/2106.09685) adapter to the model. We'll look at the effect on the memory required for training and on the achieved accuracy.

In [17]:
checkpoint = "prajjwal1/bert-tiny"
tokenizer_checkpoint = "bert-base-uncased"
dataset_name = "imdb"

## Sentiment Analysis with the IMDb Dataset

The IMDB dataset, introduced in [this 2011 paper](https://aclanthology.org/P11-1015/) from Stanford, is commonly used for sentiment analysis in the Natural Language Processing (NLP) community. This is a collection of 50k movie reviews from the IMDb website, labelled as either "positive" or "negative". Here is an example of a positive review:

> I turned over to this film in the middle of the night and very nearly skipped right passed it. It was only because there was nothing else on that I decided to watch it. In the end, I thought it was great.<br /><br />An interesting storyline, good characters, a clever script and brilliant directing makes this a fine film to sit down and watch. This was, in fact, the first I'd heard of this movie, but I would have been happy to have paid money to see this at the cinema.<br /><br />My IMDB Rating : 8 out of 10<br /><br />

And a negative review:

> its a totally average film with a few semi-alright action sequences that make the plot seem a little better and remind the viewer of the classic van dam films. parts of the plot don't make sense and seem to be added in to use up time. the end plot is that of a very basic type that doesn't leave the viewer guessing and any twists are obvious from the beginning. the end scene with the flask backs don't make sense as they are added in and seem to have little relevance to the history of van dam's character. not really worth watching again, bit disappointed in the end production, even though it is apparent it was shot on a low budget certain shots and sections in the film are of poor directed quality

The dataset is available from HuggingFace through the `datasets` library. We use the `get_tokenized_dataset` utility in Mase to tokenize it automatically.

In [18]:
from chop.tools import get_tokenized_dataset

dataset, tokenizer = get_tokenized_dataset(
    dataset=dataset_name,
    checkpoint=tokenizer_checkpoint,
    return_tokenizer=True,
)

INFO     Tokenizing dataset imdb with AutoTokenizer for bert-base-uncased.
Using the latest cached version of the dataset since imdb couldn't be found on the Hugging Face Hub
Found the latest cached dataset configuration 'plain_text' at /Users/yz10513/.cache/huggingface/datasets/imdb/plain_text/0.0.0/e6281661ce1c48d982bc483cf8a173c1bbeb5d31 (last modified on Sun Dec  1 05:38:09 2024).


## Generate a MaseGraph with Custom Arguments

By inspecting the implementation of the Bert model in HuggingFace, we can see the forward function has a signature similar to the following.

```python
from typing import Optional, Tuple, Union

import torch
from transformers.modeling_outputs import SequenceClassifierOutput
from transformers.models.bert.modeling_bert import BertModel, BertPreTrainedModel


class BertForSequenceClassification(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.bert = BertModel(config)
        ...  # dropout and classification head omitted

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
        ...  # forward pass omitted; the loss is only computed when `labels` is given
```

By default, the MaseGraph constructor traces the model using only the `input_ids` argument, ignoring the other optional arguments. However, you can specify which inputs to drive during symbolic tracing using the `hf_input_names` argument. In the following cell, we also drive the `attention_mask` and `labels` inputs. Because `labels` is provided, the traced graph also includes an `nn.CrossEntropyLoss` module at the end of the model, so the loss is calculated directly.
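The effect of leaving an optional input out of tracing can be seen on a toy module with plain `torch.fx`. This is a sketch of the general mechanism, not Mase's own tracing code, and `TinyClassifier` is a hypothetical stand-in for the HuggingFace model. When `labels` is a symbolic input, the `if labels is not None` branch is traced and the loss module appears in the graph; when `labels` is fixed to `None` through `concrete_args`, the branch is specialized away.

```python
import torch
import torch.fx
import torch.nn as nn


class TinyClassifier(nn.Module):
    """Hypothetical toy model: returns the loss when labels are given, else logits."""

    def __init__(self, in_features: int = 4, num_labels: int = 2):
        super().__init__()
        self.classifier = nn.Linear(in_features, num_labels)
        self.loss_fct = nn.CrossEntropyLoss()

    def forward(self, x, labels=None):
        logits = self.classifier(x)
        if labels is not None:
            # Traced only when `labels` is a symbolic (Proxy) input
            return self.loss_fct(logits, labels)
        return logits


def called_modules(gm: torch.fx.GraphModule) -> list:
    """Names of the submodules invoked by the traced graph."""
    return [node.target for node in gm.graph.nodes if node.op == "call_module"]


model = TinyClassifier()

# `labels` is symbolic: the Proxy object is not None, so the loss branch is traced
with_labels = torch.fx.symbolic_trace(model)

# `labels` is fixed to None: the loss branch disappears from the graph
without_labels = torch.fx.symbolic_trace(model, concrete_args={"labels": None})

print(called_modules(with_labels))     # ['classifier', 'loss_fct']
print(called_modules(without_labels))  # ['classifier']
```

This mirrors the `labels is not None` check in `BertForSequenceClassification.forward`, which only computes a loss when labels are provided.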

> **Task:** Remove the `attention_mask` and `labels` arguments from the `hf_input_names` list and re-run the following cell. Use `mg.draw()` to visualize the graph in each case. Can you see any changes in the graph topology? Can you explain why this happens?

In [19]:
from transformers import AutoModelForSequenceClassification

from chop import MaseGraph
import chop.passes as passes

model = AutoModelForSequenceClassification.from_pretrained(checkpoint)
model.config.problem_type = "single_label_classification"

mg = MaseGraph(
    model,
    hf_input_names=[
        "input_ids",
        "attention_mask",
        "labels",
    ],
)

mg, _ = passes.init_metadata_analysis_pass(mg)
mg, _ = passes.add_common_metadata_analysis_pass(mg)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at prajjwal1/bert-tiny and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
`past_key_values` were not specified as input names, but model.config.use_cache = True. Setting model.config.use_cache = False.
INFO     Getting dummy input for prajjwal1/bert-tiny.


## Full Supervised Finetuning (SFT)

Before training the model, let's inspect how many trainable parameters there are. If you're familiar with Keras, you might have used the `model.summary()` API, but PyTorch has no built-in equivalent. Luckily, Mase has a module-level pass with this functionality.
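For comparison, here is a minimal plain-PyTorch sketch of the same idea: sum `numel()` over the parameters with `requires_grad=True`, grouped by top-level submodule. The helper name and toy model are hypothetical; Mase's pass also reports nested submodules.

```python
import torch.nn as nn


def trainable_parameters_by_submodule(model: nn.Module) -> dict:
    """Count parameters with requires_grad=True, grouped by top-level child module."""
    counts = {}
    for name, child in model.named_children():
        counts[name] = sum(p.numel() for p in child.parameters() if p.requires_grad)
    return counts


# Hypothetical toy model with an embedding table and a linear head
toy = nn.Sequential()
toy.add_module("embeddings", nn.Embedding(1000, 16))
toy.add_module("head", nn.Linear(16, 2))

print(trainable_parameters_by_submodule(toy))  # {'embeddings': 16000, 'head': 34}

# Freezing a submodule removes its parameters from the count
for p in toy.embeddings.parameters():
    p.requires_grad = False
print(trainable_parameters_by_submodule(toy))  # {'embeddings': 0, 'head': 34}
```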

In [20]:
from chop.passes.module import report_trainable_parameters_analysis_pass

_, _ = report_trainable_parameters_analysis_pass(mg.model)

+-------------------------------------------------+------------------------+
| Submodule                                       |   Trainable Parameters |
| bert                                            |                4385920 |
+-------------------------------------------------+------------------------+
| bert.embeddings                                 |                3972864 |
+-------------------------------------------------+------------------------+
| bert.embeddings.word_embeddings                 |                3906816 |
+-------------------------------------------------+------------------------+
| bert.embeddings.token_type_embeddings           |                    256 |
+-------------------------------------------------+------------------------+
| bert.embeddings.position_embeddings             |                  65536 |
+-------------------------------------------------+------------------------+
| bert.embeddings.LayerNorm                       |                    256 |

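From this, we can see that the embeddings hold the vast majority of the trainable parameters (3,972,864 of the 4,385,920 in `bert`), almost all of them in the word embedding table. The pretrained embeddings already provide useful token representations, so we don't need to train them; we freeze those parameters in the cell below.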

In [21]:
for param in mg.model.bert.embeddings.parameters():
    param.requires_grad = False
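The embedding counts in the table can be reproduced by hand. This sketch assumes bert-tiny's configuration of a 30,522-token vocabulary, hidden size 128, 512 positions and 2 token types (assumptions here, but consistent with the table), plus a 2-label linear classification head outside `bert`. It also shows how many parameters remain trainable once the embeddings are frozen.

```python
# Assumed bert-tiny dimensions (consistent with the table above)
vocab_size, hidden, max_positions, type_vocab = 30522, 128, 512, 2

word = vocab_size * hidden          # word_embeddings table
position = max_positions * hidden   # position_embeddings table
token_type = type_vocab * hidden    # token_type_embeddings table
layer_norm = 2 * hidden             # LayerNorm weight + bias
embeddings_total = word + position + token_type + layer_norm

bert_total = 4_385_920              # reported for `bert` by the pass above
remaining = bert_total - embeddings_total
classifier = hidden * 2 + 2         # assumed 2-label linear head (weights + bias)

print(word, position, token_type, layer_norm)          # 3906816 65536 256 256
print(embeddings_total)                                 # 3972864
print(f"{embeddings_total / bert_total:.1%} of `bert` frozen")  # 90.6% of `bert` frozen
print(remaining, remaining + classifier)                # 413056 413314
```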

To train the model, we rely on the `Trainer` class from the `transformers` library, which makes it easy to set up a training loop across different hardware configurations. The `get_trainer` utility in Mase handles assigning the training arguments to the `Trainer` class for common use cases, such as in this tutorial.

In [22]:
from chop.tools import get_trainer

trainer = get_trainer(
    model=mg.model,
    tokenized_dataset=dataset,
    tokenizer=tokenizer,
    evaluate_metric="accuracy",
)

Before running any fine-tuning, let's see how the model performs out of the box. Since the classification head is randomly initialized, the model is effectively guessing: there are two balanced labels in the dataset, so this corresponds to an accuracy of around 50%.
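Accuracy here is the fraction of reviews whose highest-scoring logit matches the label. A minimal sketch of such a metric, assuming the `(logits, labels)` pair that HuggingFace's `Trainer` passes to a `compute_metrics` function; whether `get_trainer` computes accuracy exactly this way is an assumption. The second example uses random logits on 12,500 positive and 12,500 negative labels (the size of the IMDb test split) to show why an untrained head lands near 50%.

```python
import numpy as np


def compute_accuracy(eval_pred):
    """Accuracy from (logits, labels), the pair Trainer passes to compute_metrics."""
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return {"accuracy": float((predictions == labels).mean())}


# Worked example: 4 samples, 2 classes; predictions are [0, 1, 1, 0]
logits = np.array([[2.0, -1.0], [0.1, 0.3], [-0.5, 0.5], [1.0, 0.0]])
labels = np.array([0, 1, 0, 0])
print(compute_accuracy((logits, labels)))  # {'accuracy': 0.75}

# Logits that ignore the input score about 50% on balanced binary labels
rng = np.random.default_rng(0)
balanced = np.repeat([0, 1], 12_500)
random_logits = rng.normal(size=(balanced.size, 2))
print(compute_accuracy((random_logits, balanced)))
```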

In [23]:
# Evaluate accuracy
eval_results = trainer.evaluate()
print(f"Evaluation accuracy: {eval_results['eval_accuracy']}")

100%|██████████| 3125/3125 [01:07<00:00, 46.59it/s]

Evaluation accuracy: 0.52164





Now, run the cell below to execute a single training epoch with the current setup.

In [None]:
trainer.train()

Let's see how much accuracy we get after a single training epoch of full finetuning.

In [None]:
eval_results = trainer.evaluate()
print(f"Evaluation accuracy: {eval_results['eval_accuracy']}")

We can now export the SFT version of the model to be used in later tutorials.

In [None]:
from pathlib import Path

mg.export(f"{Path.home()}/tutorial_2_sft")

## Parameter Efficient Finetuning (PEFT) with LoRA

An alternative to full fine-tuning is Parameter-Efficient Fine-Tuning (PEFT), which trains only a small number of parameters while aiming for comparable performance. LoRA, proposed by a research team at Microsoft in 2021, is one such PEFT technique.

<div style="text-align: center;">
    <img src="imgs/lora_adapter.png" alt="drawing" width="400"/>
</div>

Consider the standard equation of a linear layer:

$$
y = X W + b
$$

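The LoRA method replaces this with the following, where, for a layer mapping $d_{in}$ input features to $d_{out}$ outputs, $A \in \mathbb{R}^{d_{in} \times r}$ and $B \in \mathbb{R}^{r \times d_{out}}$ for a small rank $r$, so the update $AB$ has rank at most $r$. We freeze the $W$ parameters and only allow the optimizer to train the parameters in $A$ and $B$.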

$$
y = X (W + AB) + b
$$

This enables us to achieve accuracies comparable to full fine-tuning while training only a fraction of the parameters. See [the paper](https://arxiv.org/abs/2106.09685) for more details. We can inject the LoRA adapter into the existing model using the `insert_lora_adapter_transform_pass` pass in Mase, as follows.
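To make the equation concrete, here is a minimal NumPy sketch of a LoRA-adapted linear layer in this tutorial's $y = X(W + AB) + b$ convention. It follows the paper's initialization (random for the first factor, zeros for the second, so $AB = 0$ at the start) and omits the $\alpha / r$ scaling the paper applies to the update; how Mase's `alpha` and `dropout` arguments enter is not shown. The sizes are illustrative, with 128 matching bert-tiny's hidden size.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d_in, d_out, r = 4, 128, 128, 6  # batch size, layer dims, LoRA rank (illustrative)

X = rng.normal(size=(n, d_in))
W = rng.normal(size=(d_in, d_out))      # frozen pretrained weight
b = rng.normal(size=d_out)              # bias
A = rng.normal(size=(d_in, r)) * 0.01   # trainable, random init
B = np.zeros((r, d_out))                # trainable, zero init so AB = 0 at the start


def lora_linear(X, W, b, A, B):
    """y = X (W + AB) + b, computed without materializing the d_in x d_out product AB."""
    return X @ W + (X @ A) @ B + b


# With B = 0 the adapter is a no-op, so training starts from the pretrained layer
assert np.allclose(lora_linear(X, W, b, A, B), X @ W + b)

# Trainable parameters: full weight matrix vs. the two LoRA factors
full_params = d_in * d_out        # 16384
lora_params = r * (d_in + d_out)  # 1536
print(full_params, lora_params, round(full_params / lora_params, 1))  # 16384 1536 10.7
```

Computing $(XA)B$ rather than forming $AB$ keeps the extra cost proportional to $r$, which is why adapters stay cheap during training.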

In [10]:
mg, _ = passes.insert_lora_adapter_transform_pass(
    mg,
    pass_args={
        "rank": 6,
        "alpha": 1.0,
        "dropout": 0.5,
    },
)

INFO     Replaced node: bert_encoder_layer_0_attention_self_query, target: bert.encoder.layer.0.attention.self.query with LoRALinear module.
INFO     Replaced node: bert_encoder_layer_0_attention_self_key, target: bert.encoder.layer.0.attention.self.key with LoRALinear module.
INFO     Replaced node: bert_encoder_layer_0_attention_self_value, target: bert.encoder.layer.0.attention.self.value with LoRALinear module.
INFO     Replaced node: bert_encoder_layer_0_attention_output_dense, target: bert.encoder.layer.0.attention.output.dense with LoRALinear module.
INFO     Replaced node: bert_encoder_layer_0_intermediate_dense, target: bert.encoder.layer.0.intermediate.dense with LoRALinear module.
INFO     Replaced node: bert_encoder_layer_0_output_dense, target: bert.encoder.layer.0.output.dense with LoRALinear module.
INFO     Replaced node: bert_encoder_layer_1_attenti

Similar to before, let's report the number of trainable parameters.

In [11]:
_, _ = report_trainable_parameters_analysis_pass(mg.model)

+-----------------------------------------------------+------------------------+
| Submodule                                           |   Trainable Parameters |
| bert                                                |                 439808 |
+-----------------------------------------------------+------------------------+
| bert.embeddings                                     |                      0 |
+-----------------------------------------------------+------------------------+
| bert.embeddings.word_embeddings                     |                      0 |
+-----------------------------------------------------+------------------------+
| bert.embeddings.token_type_embeddings               |                      0 |
+-----------------------------------------------------+------------------------+
| bert.embeddings.position_embeddings                 |                      0 |
+-----------------------------------------------------+------------------------+
| bert.embeddings.LayerNorm 

In this case, LoRA reduces the number of trainable parameters by $4.5\times$! We'll train for a single epoch and evaluate the resulting accuracy.

In [12]:
trainer = get_trainer(
    model=mg.model,
    tokenized_dataset=dataset,
    tokenizer=tokenizer,
    evaluate_metric="accuracy",
    num_train_epochs=1,
)
trainer.train()

# Evaluate accuracy
eval_results = trainer.evaluate()
print(f"Evaluation accuracy: {eval_results['eval_accuracy']}")

 16%|█▌        | 500/3125 [00:40<03:29, 12.52it/s]

{'loss': 0.6402, 'grad_norm': 2.730722665786743, 'learning_rate': 4.2e-05, 'epoch': 0.16}


 32%|███▏      | 1000/3125 [01:13<03:24, 10.41it/s]

{'loss': 0.5216, 'grad_norm': 5.168735504150391, 'learning_rate': 3.4000000000000007e-05, 'epoch': 0.32}


 48%|████▊     | 1501/3125 [01:44<01:20, 20.05it/s]

{'loss': 0.4751, 'grad_norm': 13.205789566040039, 'learning_rate': 2.6000000000000002e-05, 'epoch': 0.48}


 64%|██████▍   | 2001/3125 [02:11<01:44, 10.79it/s]

{'loss': 0.4376, 'grad_norm': 14.12922191619873, 'learning_rate': 1.8e-05, 'epoch': 0.64}


 80%|████████  | 2500/3125 [02:37<00:24, 25.37it/s]

{'loss': 0.4233, 'grad_norm': 4.498361587524414, 'learning_rate': 1e-05, 'epoch': 0.8}


 96%|█████████▌| 3002/3125 [03:01<00:10, 11.99it/s]

{'loss': 0.4164, 'grad_norm': 8.473535537719727, 'learning_rate': 2.0000000000000003e-06, 'epoch': 0.96}


100%|██████████| 3125/3125 [03:07<00:00, 16.69it/s]


{'train_runtime': 187.2829, 'train_samples_per_second': 133.488, 'train_steps_per_second': 16.686, 'train_loss': 0.48366960693359373, 'epoch': 1.0}


100%|██████████| 3125/3125 [01:08<00:00, 45.64it/s]

Evaluation accuracy: 0.8218
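The logged learning rates are consistent with a linear decay from $5 \times 10^{-5}$ (the default `learning_rate` of HuggingFace `TrainingArguments`) to zero over the 3,125 steps of the epoch, with no warmup. The sketch below checks this against the logged values; the schedule is inferred from the logs rather than read from `get_trainer`'s source, as is the batch size of 8.

```python
import math

base_lr, total_steps = 5e-5, 3125  # inferred from the logs: initial LR, steps per epoch


def linear_decay_lr(step: int) -> float:
    """Learning rate after `step` optimizer steps under linear decay to zero, no warmup."""
    return base_lr * (1 - step / total_steps)


# Learning rates printed in the training log above
logged = {500: 4.2e-05, 1000: 3.4e-05, 1500: 2.6e-05, 2000: 1.8e-05, 2500: 1e-05, 3000: 2e-06}
for step, lr in logged.items():
    assert math.isclose(linear_decay_lr(step), lr, rel_tol=1e-9)

# 3125 steps at an (inferred) batch size of 8 cover the 25,000 training reviews
print(total_steps * 8)  # 25000
```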





After training is finished, we can run the `fuse_lora_weights_transform_pass` pass to optimize the model for inference. This pass replaces each `LoRALinear` instance with an `nn.Linear` module whose weight matrix is the original weights plus the $AB$ product. This requires fewer kernel launches when deploying the model, which can reduce inference runtime.

In [15]:
mg, _ = passes.fuse_lora_weights_transform_pass(mg)
eval_results = trainer.evaluate()

100%|██████████| 3125/3125 [01:26<00:00, 36.31it/s]


In [16]:
print(f"Evaluation accuracy: {eval_results['eval_accuracy']}")

Evaluation accuracy: 0.8218
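The unchanged accuracy is expected: since $X(W + AB) + b = XW' + b$ with $W' = W + AB$, fusion is exact up to floating-point rounding, assuming the adapter's dropout (like `nn.Dropout`) is disabled in evaluation mode. A minimal NumPy check of this identity with illustrative sizes; it is not Mase's implementation of `fuse_lora_weights_transform_pass`.

```python
import numpy as np

rng = np.random.default_rng(1)
n, d_in, d_out, r = 4, 128, 512, 6  # illustrative sizes

X = rng.normal(size=(n, d_in))
W = rng.normal(size=(d_in, d_out))
b = rng.normal(size=d_out)
A = rng.normal(size=(d_in, r))
B = rng.normal(size=(r, d_out))  # non-zero, as after training

# Unfused: the adapter adds two extra matmuls to every forward pass
y_adapter = X @ W + (X @ A) @ B + b

# Fused: fold the low-rank update into the weight once, then a single matmul
W_fused = W + A @ B
y_fused = X @ W_fused + b

print(np.allclose(y_adapter, y_fused))  # True
print(W_fused.shape == W.shape)         # True
```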


## Conclusion

Finally, export the finetuned model to be used in future tutorials.

In [14]:
from pathlib import Path

mg.export(f"{Path.home()}/tutorial_2_lora")

INFO     Exporting MaseGraph to /Users/yz10513/tutorial_2_lora.pt, /Users/yz10513/tutorial_2_lora.mz
INFO     Exporting GraphModule to /Users/yz10513/tutorial_2_lora.pt
INFO     Exporting MaseMetadata to /Users/yz10513/tutorial_2_lora.mz
