Tutorial 2: Finetuning Bert for Sequence Classification using a LoRA adapter#

When we import a pretrained transformer model from HuggingFace, we receive only the encoder/decoder weights, which aren’t very useful on their own. To perform a useful task such as sequence classification, we add a classification head on top of the model and train its weights on the required dataset. In this tutorial, we’ll look at fine-tuning a Bert model for sequence classification with two approaches. First, we’ll attempt full Supervised Fine-Tuning (SFT). Then, we’ll use the Mase stack to add a LoRA adapter to the model. We’ll look at the effect on training memory requirements and on the achieved accuracy.

checkpoint = "prajjwal1/bert-tiny"
tokenizer_checkpoint = "bert-base-uncased"
dataset_name = "imdb"

Sentiment Analysis with the IMDb Dataset#

The IMDb dataset, introduced in this 2011 paper from Stanford, is commonly used for sentiment analysis in the Natural Language Processing (NLP) community. It is a collection of 50k movie reviews from the IMDb website, each labelled as either “positive” or “negative”. Here is an example of a positive review:

I turned over to this film in the middle of the night and very nearly skipped right passed it. It was only because there was nothing else on that I decided to watch it. In the end, I thought it was great.

An interesting storyline, good characters, a clever script and brilliant directing makes this a fine film to sit down and watch. This was, in fact, the first I’d heard of this movie, but I would have been happy to have paid money to see this at the cinema.

My IMDB Rating : 8 out of 10

And a negative review:

its a totally average film with a few semi-alright action sequences that make the plot seem a little better and remind the viewer of the classic van dam films. parts of the plot don’t make sense and seem to be added in to use up time. the end plot is that of a very basic type that doesn’t leave the viewer guessing and any twists are obvious from the beginning. the end scene with the flask backs don’t make sense as they are added in and seem to have little relevance to the history of van dam’s character. not really worth watching again, bit disappointed in the end production, even though it is apparent it was shot on a low budget certain shots and sections in the film are of poor directed quality

The dataset is available from HuggingFace through the datasets library. We use the get_tokenized_dataset utility in Mase to automatically tokenize it.

from chop.tools import get_tokenized_dataset

dataset, tokenizer = get_tokenized_dataset(
    dataset=dataset_name,
    checkpoint=tokenizer_checkpoint,
    return_tokenizer=True,
)
INFO     Tokenizing dataset imdb with AutoTokenizer for bert-base-uncased.
Using the latest cached version of the dataset since imdb couldn't be found on the Hugging Face Hub
Found the latest cached dataset configuration 'plain_text' at /Users/yz10513/.cache/huggingface/datasets/imdb/plain_text/0.0.0/e6281661ce1c48d982bc483cf8a173c1bbeb5d31 (last modified on Sun Dec  1 05:38:09 2024).
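
As a quick sanity check, we can inspect one tokenized sample. The snippet below is a minimal sketch that assumes the returned dataset behaves like a standard HuggingFace DatasetDict with train/test splits and the usual fields produced by the tokenizer (field names here are assumptions).

# Inspect one tokenized training sample (assumes a DatasetDict with a "train" split).
sample = dataset["train"][0]
print(sample.keys())
# Decode the first few token ids back into text to check the tokenization.
print(tokenizer.decode(sample["input_ids"][:20]))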

Generate a MaseGraph with Custom Arguments#

By inspecting the implementation of the Bert model in HuggingFace, we can see the forward function has a signature similar to the following.

    class BertForSequenceClassification(BertPreTrainedModel):
        def __init__(self, config):
            super().__init__(config)
            self.bert = BertModel(config)
            ...

        def forward(
            self,
            input_ids: Optional[torch.Tensor] = None,
            attention_mask: Optional[torch.Tensor] = None,
            token_type_ids: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.Tensor] = None,
            head_mask: Optional[torch.Tensor] = None,
            inputs_embeds: Optional[torch.Tensor] = None,
            labels: Optional[torch.Tensor] = None,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
        ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
            ...

By default, the MaseGraph constructor uses only the input_ids argument, ignoring the other optional arguments. However, you can specify which inputs to drive during symbolic tracing using the hf_input_names argument. In the following cell, we also drive the attention_mask and labels inputs. Specifying the labels argument includes an nn.CrossEntropyLoss module at the end of the model so the loss is calculated directly.

Task: Remove the attention_mask and labels arguments from the hf_input_names list and re-run the following cell. Use mg.draw() to visualize the graph in each case. Can you see any changes in the graph topology? Can you explain why this happens?

from transformers import AutoModelForSequenceClassification

from chop import MaseGraph
import chop.passes as passes

model = AutoModelForSequenceClassification.from_pretrained(checkpoint)
model.config.problem_type = "single_label_classification"

mg = MaseGraph(
    model,
    hf_input_names=[
        "input_ids",
        "attention_mask",
        "labels",
    ],
)

mg, _ = passes.init_metadata_analysis_pass(mg)
mg, _ = passes.add_common_metadata_analysis_pass(mg)
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at prajjwal1/bert-tiny and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
`past_key_values` were not specified as input names, but model.config.use_cache = True. Setting model.config.use_cache = False.
INFO     Getting dummy input for prajjwal1/bert-tiny.
(Output truncated: the remaining cell output consists of a deprecation warning and lengthy printouts of the dummy input tensors and intermediate activations, omitted here for brevity.)
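
For the task above, a minimal sketch of the comparison could look like the following. It re-traces the model driving only input_ids and draws both graphs; how mg.draw() renders its output (inline or to a file) depends on your environment.

# Re-trace the model driving only input_ids, then compare the two topologies.
mg_minimal = MaseGraph(
    AutoModelForSequenceClassification.from_pretrained(checkpoint),
    hf_input_names=["input_ids"],
)

mg.draw()          # traced with input_ids, attention_mask and labels
mg_minimal.draw()  # traced with input_ids only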

Full Supervised Finetuning (SFT)#

Before training the model, let’s inspect how many trainable parameters it has. If you’re familiar with Keras, you might have used the model.summary() API before; PyTorch has no direct equivalent, but luckily Mase has a module-level pass with this functionality.

from chop.passes.module import report_trainable_parameters_analysis_pass

_, _ = report_trainable_parameters_analysis_pass(mg.model)
+-------------------------------------------------+------------------------+
| Submodule                                       |   Trainable Parameters |
+=================================================+========================+
| bert                                            |                4385920 |
+-------------------------------------------------+------------------------+
| bert.embeddings                                 |                3972864 |
+-------------------------------------------------+------------------------+
| bert.embeddings.word_embeddings                 |                3906816 |
+-------------------------------------------------+------------------------+
| bert.embeddings.token_type_embeddings           |                    256 |
+-------------------------------------------------+------------------------+
| bert.embeddings.position_embeddings             |                  65536 |
+-------------------------------------------------+------------------------+
| bert.embeddings.LayerNorm                       |                    256 |
+-------------------------------------------------+------------------------+
| bert.embeddings.dropout                         |                      0 |
+-------------------------------------------------+------------------------+
| bert.encoder                                    |                 396544 |
+-------------------------------------------------+------------------------+
| bert.encoder.layer                              |                 396544 |
+-------------------------------------------------+------------------------+
| bert.encoder.layer.0                            |                 198272 |
+-------------------------------------------------+------------------------+
| bert.encoder.layer.0.attention                  |                  66304 |
+-------------------------------------------------+------------------------+
| bert.encoder.layer.0.attention.self             |                  49536 |
+-------------------------------------------------+------------------------+
| bert.encoder.layer.0.attention.self.query       |                  16512 |
+-------------------------------------------------+------------------------+
| bert.encoder.layer.0.attention.self.key         |                  16512 |
+-------------------------------------------------+------------------------+
| bert.encoder.layer.0.attention.self.value       |                  16512 |
+-------------------------------------------------+------------------------+
| bert.encoder.layer.0.attention.output           |                  16768 |
+-------------------------------------------------+------------------------+
| bert.encoder.layer.0.attention.output.dense     |                  16512 |
+-------------------------------------------------+------------------------+
| bert.encoder.layer.0.attention.output.dropout   |                      0 |
+-------------------------------------------------+------------------------+
| bert.encoder.layer.0.attention.output.LayerNorm |                    256 |
+-------------------------------------------------+------------------------+
| bert.encoder.layer.0.intermediate               |                  66048 |
+-------------------------------------------------+------------------------+
| bert.encoder.layer.0.intermediate.dense         |                  66048 |
+-------------------------------------------------+------------------------+
| bert.encoder.layer.0.output                     |                  65920 |
+-------------------------------------------------+------------------------+
| bert.encoder.layer.0.output.dense               |                  65664 |
+-------------------------------------------------+------------------------+
| bert.encoder.layer.0.output.dropout             |                      0 |
+-------------------------------------------------+------------------------+
| bert.encoder.layer.0.output.LayerNorm           |                    256 |
+-------------------------------------------------+------------------------+
| bert.encoder.layer.1                            |                 198272 |
+-------------------------------------------------+------------------------+
| bert.encoder.layer.1.attention                  |                  66304 |
+-------------------------------------------------+------------------------+
| bert.encoder.layer.1.attention.self             |                  49536 |
+-------------------------------------------------+------------------------+
| bert.encoder.layer.1.attention.self.query       |                  16512 |
+-------------------------------------------------+------------------------+
| bert.encoder.layer.1.attention.self.key         |                  16512 |
+-------------------------------------------------+------------------------+
| bert.encoder.layer.1.attention.self.value       |                  16512 |
+-------------------------------------------------+------------------------+
| bert.encoder.layer.1.attention.output           |                  16768 |
+-------------------------------------------------+------------------------+
| bert.encoder.layer.1.attention.output.dense     |                  16512 |
+-------------------------------------------------+------------------------+
| bert.encoder.layer.1.attention.output.dropout   |                      0 |
+-------------------------------------------------+------------------------+
| bert.encoder.layer.1.attention.output.LayerNorm |                    256 |
+-------------------------------------------------+------------------------+
| bert.encoder.layer.1.intermediate               |                  66048 |
+-------------------------------------------------+------------------------+
| bert.encoder.layer.1.intermediate.dense         |                  66048 |
+-------------------------------------------------+------------------------+
| bert.encoder.layer.1.output                     |                  65920 |
+-------------------------------------------------+------------------------+
| bert.encoder.layer.1.output.dense               |                  65664 |
+-------------------------------------------------+------------------------+
| bert.encoder.layer.1.output.dropout             |                      0 |
+-------------------------------------------------+------------------------+
| bert.encoder.layer.1.output.LayerNorm           |                    256 |
+-------------------------------------------------+------------------------+
| bert.pooler                                     |                  16512 |
+-------------------------------------------------+------------------------+
| bert.pooler.dense                               |                  16512 |
+-------------------------------------------------+------------------------+
| bert.pooler.activation                          |                      0 |
+-------------------------------------------------+------------------------+
| dropout                                         |                      0 |
+-------------------------------------------------+------------------------+
| classifier                                      |                    258 |
+-------------------------------------------------+------------------------+
| crossentropyloss_0                              |                      0 |
+-------------------------------------------------+------------------------+

Total Trainable Parameters: 14480258

From this, we can see that the majority of the trainable parameters are in the embedding layer. We don’t need to train these, so we freeze them in the cell below.

for param in mg.model.bert.embeddings.parameters():
    param.requires_grad = False
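
To confirm the freeze took effect, we can count the trainable parameters directly with plain PyTorch, independently of the Mase reporting pass.

# Count trainable vs. total parameters after freezing the embeddings.
trainable = sum(p.numel() for p in mg.model.parameters() if p.requires_grad)
total = sum(p.numel() for p in mg.model.parameters())
print(f"Trainable parameters: {trainable} / {total}")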

To train the model, we rely on the Trainer class from the transformers library, which makes it easy to set up a training loop with any hardware configuration. The get_trainer utility in Mase handles assigning the training arguments to the Trainer class for common use cases, such as in this tutorial.

from chop.tools import get_trainer

trainer = get_trainer(
    model=mg.model,
    tokenized_dataset=dataset,
    tokenizer=tokenizer,
    evaluate_metric="accuracy",
)
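
Under the hood, this is roughly what a hand-written transformers setup would look like. The sketch below is purely illustrative of the underlying Trainer API, not Mase’s actual get_trainer implementation; the argument values are placeholders and the split names are assumptions.

# Illustrative plain-transformers equivalent (not Mase's get_trainer implementation).
import evaluate
import numpy as np
from transformers import Trainer, TrainingArguments

accuracy_metric = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return accuracy_metric.compute(predictions=predictions, references=labels)

# Note: whether the traced GraphModule's outputs plug directly into a vanilla
# Trainer depends on their format; treat this purely as an API illustration.
manual_trainer = Trainer(
    model=mg.model,
    args=TrainingArguments(output_dir="tutorial_2_out", num_train_epochs=1),
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)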

Before running any fine-tuning, let’s see how the model performs out of the box. With its newly initialized classification head, the model is effectively guessing at random; since the dataset has two labels, this corresponds to an accuracy of around 50%.

# Evaluate accuracy
eval_results = trainer.evaluate()
print(f"Evaluation accuracy: {eval_results['eval_accuracy']}")
100%|██████████| 3125/3125 [01:07<00:00, 46.59it/s]
Evaluation accuracy: 0.52164

Now, run the cell below to execute a single training epoch with the current setup.

trainer.train()

Let’s see what accuracy we achieve after a single training epoch of full fine-tuning.

eval_results = trainer.evaluate()
print(f"Evaluation accuracy: {eval_results['eval_accuracy']}")

We can now export the SFT version of the model to be used in later tutorials.

from pathlib import Path

mg.export(f"{Path.home()}/tutorial_2_sft")

Parameter Efficient Finetuning (PEFT) with LoRA#

An alternative to full fine-tuning is Parameter Efficient Fine-Tuning (PEFT), which trains only a small number of parameters while achieving similar performance. LoRA, proposed by a research team at Microsoft in 2021, is an efficient technique for PEFT.


Consider the standard equation of a linear layer:

\[ y = X W + b \]

The LoRA method involves replacing this with the following, where \(A\) and \(B\) are low-rank matrices whose product has the same shape as \(W\). We freeze the \(W\) parameters and only allow the optimizer to train the parameters in \(A\) and \(B\).

\[ y = X (W + AB) + b \]

This enables us to achieve accuracies comparable to full fine-tuning while training only a fraction of the parameters. See the paper for more details. We can inject the LoRA adapter into the existing model using the insert_lora_adapter_transform_pass pass in Mase, as follows.

mg, _ = passes.insert_lora_adapter_transform_pass(
    mg,
    pass_args={
        "rank": 6,
        "alpha": 1.0,
        "dropout": 0.5,
    },
)
INFO     Replaced node: bert_encoder_layer_0_attention_self_query, target: bert.encoder.layer.0.attention.self.query with LoRALinear module.
INFO     Replaced node: bert_encoder_layer_0_attention_self_key, target: bert.encoder.layer.0.attention.self.key with LoRALinear module.
INFO     Replaced node: bert_encoder_layer_0_attention_self_value, target: bert.encoder.layer.0.attention.self.value with LoRALinear module.
INFO     Replaced node: bert_encoder_layer_0_attention_output_dense, target: bert.encoder.layer.0.attention.output.dense with LoRALinear module.
INFO     Replaced node: bert_encoder_layer_0_intermediate_dense, target: bert.encoder.layer.0.intermediate.dense with LoRALinear module.
INFO     Replaced node: bert_encoder_layer_0_output_dense, target: bert.encoder.layer.0.output.dense with LoRALinear module.
INFO     Replaced node: bert_encoder_layer_1_attention_self_query, target: bert.encoder.layer.1.attention.self.query with LoRALinear module.
INFO     Replaced node: bert_encoder_layer_1_attention_self_key, target: bert.encoder.layer.1.attention.self.key with LoRALinear module.
INFO     Replaced node: bert_encoder_layer_1_attention_self_value, target: bert.encoder.layer.1.attention.self.value with LoRALinear module.
INFO     Replaced node: bert_encoder_layer_1_attention_output_dense, target: bert.encoder.layer.1.attention.output.dense with LoRALinear module.
INFO     Replaced node: bert_encoder_layer_1_intermediate_dense, target: bert.encoder.layer.1.intermediate.dense with LoRALinear module.
INFO     Replaced node: bert_encoder_layer_1_output_dense, target: bert.encoder.layer.1.output.dense with LoRALinear module.
INFO     Replaced node: bert_pooler_dense, target: bert.pooler.dense with LoRALinear module.
INFO     Replaced node: classifier, target: classifier with LoRALinear module.

As before, let’s report the number of trainable parameters.

_, _ = report_trainable_parameters_analysis_pass(mg.model)
+-----------------------------------------------------+------------------------+
| Submodule                                           |   Trainable Parameters |
+=====================================================+========================+
| bert                                                |                 439808 |
+-----------------------------------------------------+------------------------+
| bert.embeddings                                     |                      0 |
+-----------------------------------------------------+------------------------+
| bert.embeddings.word_embeddings                     |                      0 |
+-----------------------------------------------------+------------------------+
| bert.embeddings.token_type_embeddings               |                      0 |
+-----------------------------------------------------+------------------------+
| bert.embeddings.position_embeddings                 |                      0 |
+-----------------------------------------------------+------------------------+
| bert.embeddings.LayerNorm                           |                      0 |
+-----------------------------------------------------+------------------------+
| bert.embeddings.dropout                             |                      0 |
+-----------------------------------------------------+------------------------+
| bert.encoder                                        |                 421888 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer                                  |                 421888 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0                                |                 210944 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.attention                      |                  71936 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.attention.self                 |                  53760 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.attention.self.query           |                  17920 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.attention.self.query.linear    |                  16384 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.attention.self.query.lora_a    |                    768 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.attention.self.query.lora_b    |                    768 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.attention.self.query.dropout   |                      0 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.attention.self.key             |                  17920 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.attention.self.key.linear      |                  16384 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.attention.self.key.lora_a      |                    768 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.attention.self.key.lora_b      |                    768 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.attention.self.key.dropout     |                      0 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.attention.self.value           |                  17920 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.attention.self.value.linear    |                  16384 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.attention.self.value.lora_a    |                    768 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.attention.self.value.lora_b    |                    768 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.attention.self.value.dropout   |                      0 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.attention.output               |                  18176 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.attention.output.dense         |                  17920 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.attention.output.dense.linear  |                  16384 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.attention.output.dense.lora_a  |                    768 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.attention.output.dense.lora_b  |                    768 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.attention.output.dense.dropout |                      0 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.attention.output.dropout       |                      0 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.attention.output.LayerNorm     |                    256 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.intermediate                   |                  69376 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.intermediate.dense             |                  69376 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.intermediate.dense.linear      |                  65536 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.intermediate.dense.lora_a      |                    768 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.intermediate.dense.lora_b      |                   3072 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.intermediate.dense.dropout     |                      0 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.output                         |                  69632 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.output.dense                   |                  69376 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.output.dense.linear            |                  65536 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.output.dense.lora_a            |                   3072 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.output.dense.lora_b            |                    768 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.output.dense.dropout           |                      0 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.output.dropout                 |                      0 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.output.LayerNorm               |                    256 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1                                |                 210944 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.attention                      |                  71936 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.attention.self                 |                  53760 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.attention.self.query           |                  17920 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.attention.self.query.linear    |                  16384 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.attention.self.query.lora_a    |                    768 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.attention.self.query.lora_b    |                    768 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.attention.self.query.dropout   |                      0 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.attention.self.key             |                  17920 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.attention.self.key.linear      |                  16384 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.attention.self.key.lora_a      |                    768 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.attention.self.key.lora_b      |                    768 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.attention.self.key.dropout     |                      0 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.attention.self.value           |                  17920 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.attention.self.value.linear    |                  16384 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.attention.self.value.lora_a    |                    768 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.attention.self.value.lora_b    |                    768 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.attention.self.value.dropout   |                      0 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.attention.output               |                  18176 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.attention.output.dense         |                  17920 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.attention.output.dense.linear  |                  16384 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.attention.output.dense.lora_a  |                    768 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.attention.output.dense.lora_b  |                    768 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.attention.output.dense.dropout |                      0 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.attention.output.dropout       |                      0 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.attention.output.LayerNorm     |                    256 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.intermediate                   |                  69376 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.intermediate.dense             |                  69376 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.intermediate.dense.linear      |                  65536 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.intermediate.dense.lora_a      |                    768 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.intermediate.dense.lora_b      |                   3072 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.intermediate.dense.dropout     |                      0 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.output                         |                  69632 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.output.dense                   |                  69376 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.output.dense.linear            |                  65536 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.output.dense.lora_a            |                   3072 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.output.dense.lora_b            |                    768 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.output.dense.dropout           |                      0 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.output.dropout                 |                      0 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.output.LayerNorm               |                    256 |
+-----------------------------------------------------+------------------------+
| bert.pooler                                         |                  17920 |
+-----------------------------------------------------+------------------------+
| bert.pooler.dense                                   |                  17920 |
+-----------------------------------------------------+------------------------+
| bert.pooler.dense.linear                            |                  16384 |
+-----------------------------------------------------+------------------------+
| bert.pooler.dense.lora_a                            |                    768 |
+-----------------------------------------------------+------------------------+
| bert.pooler.dense.lora_b                            |                    768 |
+-----------------------------------------------------+------------------------+
| bert.pooler.dense.dropout                           |                      0 |
+-----------------------------------------------------+------------------------+
| bert.pooler.activation                              |                      0 |
+-----------------------------------------------------+------------------------+
| dropout                                             |                      0 |
+-----------------------------------------------------+------------------------+
| classifier                                          |                   1036 |
+-----------------------------------------------------+------------------------+
| classifier.linear                                   |                    256 |
+-----------------------------------------------------+------------------------+
| classifier.lora_a                                   |                    768 |
+-----------------------------------------------------+------------------------+
| classifier.lora_b                                   |                     12 |
+-----------------------------------------------------+------------------------+
| classifier.dropout                                  |                      0 |
+-----------------------------------------------------+------------------------+
| crossentropyloss_0                                  |                      0 |
+-----------------------------------------------------+------------------------+

Total Trainable Parameters: 3169816

In this case, LoRA reduces the number of trainable parameters by roughly \(4.5\times\)! We’ll now run a single training epoch with the LoRA-adapted model and evaluate the resulting accuracy.

# Re-create the trainer so its optimizer picks up the LoRA-adapted model
trainer = get_trainer(
    model=mg.model,
    tokenized_dataset=dataset,
    tokenizer=tokenizer,
    evaluate_metric="accuracy",
    num_train_epochs=1,
)
trainer.train()

# Evaluate accuracy
eval_results = trainer.evaluate()
print(f"Evaluation accuracy: {eval_results['eval_accuracy']}")
 16%|█▌        | 500/3125 [00:40<03:29, 12.52it/s]
{'loss': 0.6402, 'grad_norm': 2.730722665786743, 'learning_rate': 4.2e-05, 'epoch': 0.16}
 32%|███▏      | 1000/3125 [01:13<03:24, 10.41it/s]
{'loss': 0.5216, 'grad_norm': 5.168735504150391, 'learning_rate': 3.4000000000000007e-05, 'epoch': 0.32}
 48%|████▊     | 1501/3125 [01:44<01:20, 20.05it/s]
{'loss': 0.4751, 'grad_norm': 13.205789566040039, 'learning_rate': 2.6000000000000002e-05, 'epoch': 0.48}
 64%|██████▍   | 2001/3125 [02:11<01:44, 10.79it/s]
{'loss': 0.4376, 'grad_norm': 14.12922191619873, 'learning_rate': 1.8e-05, 'epoch': 0.64}
 80%|████████  | 2500/3125 [02:37<00:24, 25.37it/s]
{'loss': 0.4233, 'grad_norm': 4.498361587524414, 'learning_rate': 1e-05, 'epoch': 0.8}
 96%|█████████▌| 3002/3125 [03:01<00:10, 11.99it/s]
{'loss': 0.4164, 'grad_norm': 8.473535537719727, 'learning_rate': 2.0000000000000003e-06, 'epoch': 0.96}
100%|██████████| 3125/3125 [03:07<00:00, 16.69it/s]
{'train_runtime': 187.2829, 'train_samples_per_second': 133.488, 'train_steps_per_second': 16.686, 'train_loss': 0.48366960693359373, 'epoch': 1.0}
100%|██████████| 3125/3125 [01:08<00:00, 45.64it/s]
Evaluation accuracy: 0.8218

After training is finished, we can run the fuse_lora_weights_transform_pass pass to optimize the model for inference. This pass replaces each LoRALinear instance with a plain nn.Linear module whose weight matrix is the original \(W\) with the \(AB\) product folded in. Fusing the adapter means fewer kernel invocations when deploying the model, which reduces inference runtime.
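
For intuition, here is what the fusion amounts to for the hypothetical LoRALinearSketch module sketched earlier - an illustration of the idea, not the implementation of the Mase pass.

import torch
import torch.nn as nn

def fuse_lora_sketch(layer):
    """Fold the low-rank update of a LoRALinearSketch into a plain nn.Linear."""
    fused = nn.Linear(layer.linear.in_features, layer.linear.out_features)
    with torch.no_grad():
        # nn.Linear stores weights as (out_features, in_features), so transpose
        # the (in_features x out_features) AB product before adding it to W
        delta = layer.scaling * (layer.lora_a @ layer.lora_b).T
        fused.weight.copy_(layer.linear.weight + delta)
        fused.bias.copy_(layer.linear.bias)
    return fused

Running the actual Mase pass and re-evaluating shows the fused model matches the adapter model’s accuracy.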

mg, _ = passes.fuse_lora_weights_transform_pass(mg)
eval_results = trainer.evaluate()
print(f"Evaluation accuracy: {eval_results['eval_accuracy']}")
100%|██████████| 3125/3125 [01:26<00:00, 36.31it/s]
Evaluation accuracy: 0.8218

Conclusion#

In this tutorial, we fine tuned a Bert model for sequence classification, first with full SFT and then with a LoRA adapter, which reached around 82% accuracy on IMDb while training roughly \(4.5\times\) fewer parameters. Finally, export the finetuned model to be used in future tutorials.

from pathlib import Path

mg.export(f"{Path.home()}/tutorial_2_lora")
INFO     Exporting MaseGraph to /Users/yz10513/tutorial_2_lora.pt, /Users/yz10513/tutorial_2_lora.mz
INFO     Exporting GraphModule to /Users/yz10513/tutorial_2_lora.pt
INFO     Exporting MaseMetadata to /Users/yz10513/tutorial_2_lora.mz
WARNING  Failed to pickle call_function node: finfo
WARNING  cannot pickle 'torch.finfo' object
WARNING  Failed to pickle call_function node: getattr_2
WARNING  cannot pickle 'torch.finfo' object