Tutorial 2: Finetuning Bert for Sequence Classification using a LoRA adapter#

When we import a pretrained transformer model from HuggingFace, we receive only the encoder/decoder weights, which aren’t directly useful on their own - to perform a task such as sequence classification, we add a classification head on top of the model and train its weights on the required dataset. In this tutorial, we’ll look at fine-tuning a Bert model for sequence classification with two approaches. First, we’ll attempt full Supervised Fine-Tuning (SFT). Then, we’ll use the Mase stack to add a LoRA adapter to the model. In each case, we’ll look at the memory required for training and the accuracy achieved.

checkpoint = "DeepWokLab/bert-tiny"            # model weights to finetune
tokenizer_checkpoint = "DeepWokLab/bert-tiny"  # tokenizer matching the model
dataset_name = "imdb"                          # HuggingFace dataset identifier

Sentiment Analysis with the IMDb Dataset#

The IMDb dataset, introduced in this 2011 paper from Stanford, is commonly used for sentiment analysis in the Natural Language Processing (NLP) community. It is a collection of 50k movie reviews from the IMDb website, each labelled as either “positive” or “negative”. Here is an example of a positive review:

I turned over to this film in the middle of the night and very nearly skipped right passed it. It was only because there was nothing else on that I decided to watch it. In the end, I thought it was great.

An interesting storyline, good characters, a clever script and brilliant directing makes this a fine film to sit down and watch. This was, in fact, the first I’d heard of this movie, but I would have been happy to have paid money to see this at the cinema.

My IMDB Rating : 8 out of 10

And a negative review:

its a totally average film with a few semi-alright action sequences that make the plot seem a little better and remind the viewer of the classic van dam films. parts of the plot don’t make sense and seem to be added in to use up time. the end plot is that of a very basic type that doesn’t leave the viewer guessing and any twists are obvious from the beginning. the end scene with the flask backs don’t make sense as they are added in and seem to have little relevance to the history of van dam’s character. not really worth watching again, bit disappointed in the end production, even though it is apparent it was shot on a low budget certain shots and sections in the film are of poor directed quality

The dataset is available from HuggingFace through the datasets library. We use the get_tokenized_dataset utility in Mase to download and tokenize it automatically.

from chop.tools import get_tokenized_dataset

dataset, tokenizer = get_tokenized_dataset(
    dataset=dataset_name,
    checkpoint=tokenizer_checkpoint,
    return_tokenizer=True,
)
INFO     Tokenizing dataset imdb with AutoTokenizer for DeepWokLab/bert-tiny.
Map: 100%|██████████| 25000/25000 [00:03<00:00, 6535.11 examples/s]
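
For reference, get_tokenized_dataset wraps the standard HuggingFace workflow; a rough manual equivalent (a sketch, not Mase’s exact implementation) looks like this:

from datasets import load_dataset
from transformers import AutoTokenizer

# Download the raw IMDb splits and the tokenizer matching the model
raw_dataset = load_dataset("imdb")
hf_tokenizer = AutoTokenizer.from_pretrained(tokenizer_checkpoint)

def tokenize_fn(batch):
    # Truncate/pad each review to the model's maximum sequence length
    return hf_tokenizer(batch["text"], truncation=True, padding="max_length")

tokenized_dataset = raw_dataset.map(tokenize_fn, batched=True)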

Generate a MaseGraph with Custom Arguments#

By inspecting the implementation of the Bert model in HuggingFace, we can see that the forward function has a signature similar to the following.

    class BertForSequenceClassification(BertPreTrainedModel):
        def __init__(self, config):
            super().__init__(config)
            self.bert = BertModel(config)
            ...

        def forward(
            self,
            input_ids: Optional[torch.Tensor] = None,
            attention_mask: Optional[torch.Tensor] = None,
            token_type_ids: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.Tensor] = None,
            head_mask: Optional[torch.Tensor] = None,
            inputs_embeds: Optional[torch.Tensor] = None,
            labels: Optional[torch.Tensor] = None,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
        ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
            ...

By default, the MaseGraph constructor drives only the input_ids argument during tracing, ignoring the other optional arguments. However, you can specify which inputs to drive during symbolic tracing using the hf_input_names argument. In the following cell, we also drive the attention_mask and labels inputs. By specifying the labels argument, we include an nn.CrossEntropyLoss module at the end of the model to calculate the loss directly.

Task: Remove the attention_mask and labels arguments from the hf_input_names list and re-run the following cell. Use mg.draw() to visualize the graph in each case. Can you see any changes in the graph topology? Can you explain why this happens?

from transformers import AutoModelForSequenceClassification

from chop import MaseGraph
import chop.passes as passes

model = AutoModelForSequenceClassification.from_pretrained(checkpoint)
model.config.problem_type = "single_label_classification"

mg = MaseGraph(
    model,
    hf_input_names=[
        "input_ids",
        "attention_mask",
        "labels",
    ],
)

mg, _ = passes.init_metadata_analysis_pass(mg)
mg, _ = passes.add_common_metadata_analysis_pass(mg)
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at DeepWokLab/bert-tiny and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
`past_key_values` were not specified as input names, but model.config.use_cache = True. Setting model.config.use_cache = False.
INFO     Getting dummy input for DeepWokLab/bert-tiny.
(Output truncated: the pass propagates the dummy inputs through every node in the graph, printing the intermediate tensors at each stage and ending with the classifier logits and labels.)
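
As suggested in the task above, you can render the traced graph to inspect its topology:

# Visualize the MaseGraph topology (output format/location may vary by Mase version)
mg.draw()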

Full Supervised Finetuning (SFT)#

Before training the model, let’s inspect how many trainable parameters there are. If you’re familiar with Keras, you might have used the model.summary() API before, but there’s no equally convenient built-in in PyTorch - luckily, Mase has a module-level pass with this functionality.

from chop.passes.module import report_trainable_parameters_analysis_pass

_, _ = report_trainable_parameters_analysis_pass(mg.model)
+-------------------------------------------------+------------------------+
| Submodule                                       |   Trainable Parameters |
+=================================================+========================+
| bert                                            |                4385920 |
+-------------------------------------------------+------------------------+
| bert.embeddings                                 |                3972864 |
+-------------------------------------------------+------------------------+
| bert.embeddings.word_embeddings                 |                3906816 |
+-------------------------------------------------+------------------------+
| bert.embeddings.token_type_embeddings           |                    256 |
+-------------------------------------------------+------------------------+
| bert.embeddings.position_embeddings             |                  65536 |
+-------------------------------------------------+------------------------+
| bert.embeddings.LayerNorm                       |                    256 |
+-------------------------------------------------+------------------------+
| bert.embeddings.dropout                         |                      0 |
+-------------------------------------------------+------------------------+
| bert.encoder                                    |                 396544 |
+-------------------------------------------------+------------------------+
| bert.encoder.layer                              |                 396544 |
+-------------------------------------------------+------------------------+
| bert.encoder.layer.0                            |                 198272 |
+-------------------------------------------------+------------------------+
| bert.encoder.layer.0.attention                  |                  66304 |
+-------------------------------------------------+------------------------+
| bert.encoder.layer.0.attention.self             |                  49536 |
+-------------------------------------------------+------------------------+
| bert.encoder.layer.0.attention.self.query       |                  16512 |
+-------------------------------------------------+------------------------+
| bert.encoder.layer.0.attention.self.key         |                  16512 |
+-------------------------------------------------+------------------------+
| bert.encoder.layer.0.attention.self.value       |                  16512 |
+-------------------------------------------------+------------------------+
| bert.encoder.layer.0.attention.output           |                  16768 |
+-------------------------------------------------+------------------------+
| bert.encoder.layer.0.attention.output.dense     |                  16512 |
+-------------------------------------------------+------------------------+
| bert.encoder.layer.0.attention.output.dropout   |                      0 |
+-------------------------------------------------+------------------------+
| bert.encoder.layer.0.attention.output.LayerNorm |                    256 |
+-------------------------------------------------+------------------------+
| bert.encoder.layer.0.intermediate               |                  66048 |
+-------------------------------------------------+------------------------+
| bert.encoder.layer.0.intermediate.dense         |                  66048 |
+-------------------------------------------------+------------------------+
| bert.encoder.layer.0.output                     |                  65920 |
+-------------------------------------------------+------------------------+
| bert.encoder.layer.0.output.dense               |                  65664 |
+-------------------------------------------------+------------------------+
| bert.encoder.layer.0.output.dropout             |                      0 |
+-------------------------------------------------+------------------------+
| bert.encoder.layer.0.output.LayerNorm           |                    256 |
+-------------------------------------------------+------------------------+
| bert.encoder.layer.1                            |                 198272 |
+-------------------------------------------------+------------------------+
| bert.encoder.layer.1.attention                  |                  66304 |
+-------------------------------------------------+------------------------+
| bert.encoder.layer.1.attention.self             |                  49536 |
+-------------------------------------------------+------------------------+
| bert.encoder.layer.1.attention.self.query       |                  16512 |
+-------------------------------------------------+------------------------+
| bert.encoder.layer.1.attention.self.key         |                  16512 |
+-------------------------------------------------+------------------------+
| bert.encoder.layer.1.attention.self.value       |                  16512 |
+-------------------------------------------------+------------------------+
| bert.encoder.layer.1.attention.output           |                  16768 |
+-------------------------------------------------+------------------------+
| bert.encoder.layer.1.attention.output.dense     |                  16512 |
+-------------------------------------------------+------------------------+
| bert.encoder.layer.1.attention.output.dropout   |                      0 |
+-------------------------------------------------+------------------------+
| bert.encoder.layer.1.attention.output.LayerNorm |                    256 |
+-------------------------------------------------+------------------------+
| bert.encoder.layer.1.intermediate               |                  66048 |
+-------------------------------------------------+------------------------+
| bert.encoder.layer.1.intermediate.dense         |                  66048 |
+-------------------------------------------------+------------------------+
| bert.encoder.layer.1.output                     |                  65920 |
+-------------------------------------------------+------------------------+
| bert.encoder.layer.1.output.dense               |                  65664 |
+-------------------------------------------------+------------------------+
| bert.encoder.layer.1.output.dropout             |                      0 |
+-------------------------------------------------+------------------------+
| bert.encoder.layer.1.output.LayerNorm           |                    256 |
+-------------------------------------------------+------------------------+
| bert.pooler                                     |                  16512 |
+-------------------------------------------------+------------------------+
| bert.pooler.dense                               |                  16512 |
+-------------------------------------------------+------------------------+
| bert.pooler.activation                          |                      0 |
+-------------------------------------------------+------------------------+
| dropout                                         |                      0 |
+-------------------------------------------------+------------------------+
| classifier                                      |                    258 |
+-------------------------------------------------+------------------------+
| crossentropyloss_0                              |                      0 |
+-------------------------------------------------+------------------------+

Total Trainable Parameters: 14480258

From this, we can see that the majority of the trainable parameters are in the embedding layer. We don’t need to train these for this task, so we freeze them in the cell below.

for param in mg.model.bert.embeddings.parameters():
    param.requires_grad = False
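
We can double-check the effect of freezing with plain PyTorch:

# Count parameters the optimizer will actually update
trainable = sum(p.numel() for p in mg.model.parameters() if p.requires_grad)
total = sum(p.numel() for p in mg.model.parameters())
print(f"Trainable parameters after freezing: {trainable} / {total}")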

To train the model, we rely on the Trainer class from the transformers library, which makes it easy to set up a training loop on any hardware configuration. The get_trainer utility in Mase handles assigning the training arguments to the Trainer class for common use cases, such as in this tutorial.

from chop.tools import get_trainer

trainer = get_trainer(
    model=mg.model,
    tokenized_dataset=dataset,
    tokenizer=tokenizer,
    evaluate_metric="accuracy",
)
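
For context, get_trainer assembles a standard transformers Trainer under the hood; a hand-rolled equivalent would look roughly like the sketch below. The training arguments shown are illustrative assumptions, not Mase’s actual defaults.

import numpy as np
import evaluate
from transformers import Trainer, TrainingArguments

accuracy_metric = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    # Convert logits to class predictions before scoring
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return accuracy_metric.compute(predictions=predictions, references=labels)

manual_trainer = Trainer(
    model=mg.model,
    args=TrainingArguments(output_dir="./ckpts", num_train_epochs=1),
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    compute_metrics=compute_metrics,
)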

Before running any fine-tuning, let’s see how the model performs out of the box. Without any fine-tuning, the model performs no better than random guessing - since the dataset has two labels, this corresponds to an accuracy of around 50%.

# Evaluate accuracy
eval_results = trainer.evaluate()
print(f"Evaluation accuracy: {eval_results['eval_accuracy']}")
[782/782 01:29]
Evaluation accuracy: 0.49944

Now, run the cell below to execute a single training epoch with the current setup.

trainer.train()
[782/782 00:34, Epoch 1/1]
Step Training Loss
500 0.577800

TrainOutput(global_step=782, training_loss=0.5414889596612252, metrics={'train_runtime': 34.4035, 'train_samples_per_second': 726.671, 'train_steps_per_second': 22.73, 'total_flos': 0.0, 'train_loss': 0.5414889596612252, 'epoch': 1.0})

Let’s see what accuracy we achieve after a single epoch of full finetuning.

eval_results = trainer.evaluate()
print(f"Evaluation accuracy: {eval_results['eval_accuracy']}")
Evaluation accuracy: 0.788

We can now export the SFT version of the model to be used in later tutorials.

from pathlib import Path

mg.export(f"{Path.home()}/tutorial_2_sft")
INFO     Exporting MaseGraph to /home/zz7522/tutorial_2_sft.pt, /home/zz7522/tutorial_2_sft.mz
INFO     Exporting GraphModule to /home/zz7522/tutorial_2_sft.pt
INFO     Exporting MaseMetadata to /home/zz7522/tutorial_2_sft.mz
WARNING  Failed to pickle call_function node: finfo
WARNING  cannot pickle 'torch.finfo' object
WARNING  Failed to pickle call_function node: getattr_2
WARNING  cannot pickle 'torch.finfo' object

Parameter Efficient Finetuning (PEFT) with LoRA#

An alternative to full fine-tuning is Parameter Efficient Fine Tuning (PEFT), which trains only a small number of parameters while achieving similar performance. LoRA was proposed by a research team at Microsoft in 2021 as an efficient PEFT technique.


Consider the standard equation of a linear layer:

\[ y = X W + b \]

The LoRA method replaces this with the following, where \(A\) and \(B\) are low-rank matrices of shape \(d_{in} \times r\) and \(r \times d_{out}\) respectively, with rank \(r \ll \min(d_{in}, d_{out})\). We freeze the \(W\) parameters and only allow the optimizer to train the parameters in \(A\) and \(B\).

\[ y = X (W + AB) + b \]
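
To make this concrete, below is a minimal sketch of what such an adapter layer could look like in PyTorch. This is illustrative only - the class name LoRALinearSketch is ours, Mase’s actual LoRALinear may differ in its details, and we assume the common convention from the original paper of scaling the update by \(\alpha / r\).

import torch
import torch.nn as nn

class LoRALinearSketch(nn.Module):
    """Illustrative LoRA adapter around a frozen linear layer (not Mase's implementation)."""

    def __init__(self, in_features, out_features, rank, alpha=1.0, dropout=0.0):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)
        for p in self.linear.parameters():
            p.requires_grad = False  # freeze W (and b); only A and B are trained
        # A starts small and B starts at zero, so the update AB is zero at initialisation
        self.lora_a = nn.Parameter(torch.randn(in_features, rank) * 0.01)
        self.lora_b = nn.Parameter(torch.zeros(rank, out_features))
        self.dropout = nn.Dropout(dropout)
        self.scaling = alpha / rank

    def forward(self, x):
        frozen = self.linear(x)                               # X W + b
        update = self.dropout(x) @ self.lora_a @ self.lora_b  # X A B
        return frozen + self.scaling * update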

This enables us to achieve accuracy comparable to full fine tuning while training only a fraction of the parameters. See the paper for more details. We can inject the LoRA adapter into the existing model using the insert_lora_adapter_transform_pass pass in Mase, as follows.

mg, _ = passes.insert_lora_adapter_transform_pass(
    mg,
    pass_args={
        "rank": 6,       # rank r of the low-rank matrices A and B
        "alpha": 1.0,    # scaling applied to the low-rank update
        "dropout": 0.5,  # dropout on the adapter path during training
    },
)
INFO     Replaced node: bert_encoder_layer_0_attention_self_query, target: bert.encoder.layer.0.attention.self.query with LoRALinear module.
INFO     Replaced node: bert_encoder_layer_0_attention_self_key, target: bert.encoder.layer.0.attention.self.key with LoRALinear module.
INFO     Replaced node: bert_encoder_layer_0_attention_self_value, target: bert.encoder.layer.0.attention.self.value with LoRALinear module.
INFO     Replaced node: bert_encoder_layer_0_attention_output_dense, target: bert.encoder.layer.0.attention.output.dense with LoRALinear module.
INFO     Replaced node: bert_encoder_layer_0_intermediate_dense, target: bert.encoder.layer.0.intermediate.dense with LoRALinear module.
INFO     Replaced node: bert_encoder_layer_0_output_dense, target: bert.encoder.layer.0.output.dense with LoRALinear module.
INFO     Replaced node: bert_encoder_layer_1_attention_self_query, target: bert.encoder.layer.1.attention.self.query with LoRALinear module.
INFO     Replaced node: bert_encoder_layer_1_attention_self_key, target: bert.encoder.layer.1.attention.self.key with LoRALinear module.
INFO     Replaced node: bert_encoder_layer_1_attention_self_value, target: bert.encoder.layer.1.attention.self.value with LoRALinear module.
INFO     Replaced node: bert_encoder_layer_1_attention_output_dense, target: bert.encoder.layer.1.attention.output.dense with LoRALinear module.
INFO     Replaced node: bert_encoder_layer_1_intermediate_dense, target: bert.encoder.layer.1.intermediate.dense with LoRALinear module.
INFO     Replaced node: bert_encoder_layer_1_output_dense, target: bert.encoder.layer.1.output.dense with LoRALinear module.
INFO     Replaced node: bert_pooler_dense, target: bert.pooler.dense with LoRALinear module.
INFO     Replaced node: classifier, target: classifier with LoRALinear module.

As before, let’s report the number of trainable parameters.

_, _ = report_trainable_parameters_analysis_pass(mg.model)
+-----------------------------------------------------+------------------------+
| Submodule                                           |   Trainable Parameters |
+=====================================================+========================+
| bert                                                |                 439808 |
+-----------------------------------------------------+------------------------+
| bert.embeddings                                     |                      0 |
+-----------------------------------------------------+------------------------+
| bert.embeddings.word_embeddings                     |                      0 |
+-----------------------------------------------------+------------------------+
| bert.embeddings.token_type_embeddings               |                      0 |
+-----------------------------------------------------+------------------------+
| bert.embeddings.position_embeddings                 |                      0 |
+-----------------------------------------------------+------------------------+
| bert.embeddings.LayerNorm                           |                      0 |
+-----------------------------------------------------+------------------------+
| bert.embeddings.dropout                             |                      0 |
+-----------------------------------------------------+------------------------+
| bert.encoder                                        |                 421888 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer                                  |                 421888 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0                                |                 210944 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.attention                      |                  71936 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.attention.self                 |                  53760 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.attention.self.query           |                  17920 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.attention.self.query.linear    |                  16384 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.attention.self.query.lora_a    |                    768 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.attention.self.query.lora_b    |                    768 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.attention.self.query.dropout   |                      0 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.attention.self.key             |                  17920 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.attention.self.key.linear      |                  16384 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.attention.self.key.lora_a      |                    768 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.attention.self.key.lora_b      |                    768 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.attention.self.key.dropout     |                      0 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.attention.self.value           |                  17920 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.attention.self.value.linear    |                  16384 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.attention.self.value.lora_a    |                    768 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.attention.self.value.lora_b    |                    768 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.attention.self.value.dropout   |                      0 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.attention.output               |                  18176 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.attention.output.dense         |                  17920 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.attention.output.dense.linear  |                  16384 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.attention.output.dense.lora_a  |                    768 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.attention.output.dense.lora_b  |                    768 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.attention.output.dense.dropout |                      0 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.attention.output.dropout       |                      0 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.attention.output.LayerNorm     |                    256 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.intermediate                   |                  69376 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.intermediate.dense             |                  69376 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.intermediate.dense.linear      |                  65536 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.intermediate.dense.lora_a      |                    768 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.intermediate.dense.lora_b      |                   3072 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.intermediate.dense.dropout     |                      0 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.output                         |                  69632 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.output.dense                   |                  69376 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.output.dense.linear            |                  65536 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.output.dense.lora_a            |                   3072 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.output.dense.lora_b            |                    768 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.output.dense.dropout           |                      0 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.output.dropout                 |                      0 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.output.LayerNorm               |                    256 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1                                |                 210944 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.attention                      |                  71936 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.attention.self                 |                  53760 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.attention.self.query           |                  17920 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.attention.self.query.linear    |                  16384 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.attention.self.query.lora_a    |                    768 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.attention.self.query.lora_b    |                    768 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.attention.self.query.dropout   |                      0 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.attention.self.key             |                  17920 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.attention.self.key.linear      |                  16384 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.attention.self.key.lora_a      |                    768 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.attention.self.key.lora_b      |                    768 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.attention.self.key.dropout     |                      0 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.attention.self.value           |                  17920 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.attention.self.value.linear    |                  16384 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.attention.self.value.lora_a    |                    768 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.attention.self.value.lora_b    |                    768 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.attention.self.value.dropout   |                      0 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.attention.output               |                  18176 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.attention.output.dense         |                  17920 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.attention.output.dense.linear  |                  16384 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.attention.output.dense.lora_a  |                    768 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.attention.output.dense.lora_b  |                    768 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.attention.output.dense.dropout |                      0 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.attention.output.dropout       |                      0 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.attention.output.LayerNorm     |                    256 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.intermediate                   |                  69376 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.intermediate.dense             |                  69376 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.intermediate.dense.linear      |                  65536 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.intermediate.dense.lora_a      |                    768 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.intermediate.dense.lora_b      |                   3072 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.intermediate.dense.dropout     |                      0 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.output                         |                  69632 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.output.dense                   |                  69376 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.output.dense.linear            |                  65536 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.output.dense.lora_a            |                   3072 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.output.dense.lora_b            |                    768 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.output.dense.dropout           |                      0 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.output.dropout                 |                      0 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.output.LayerNorm               |                    256 |
+-----------------------------------------------------+------------------------+
| bert.pooler                                         |                  17920 |
+-----------------------------------------------------+------------------------+
| bert.pooler.dense                                   |                  17920 |
+-----------------------------------------------------+------------------------+
| bert.pooler.dense.linear                            |                  16384 |
+-----------------------------------------------------+------------------------+
| bert.pooler.dense.lora_a                            |                    768 |
+-----------------------------------------------------+------------------------+
| bert.pooler.dense.lora_b                            |                    768 |
+-----------------------------------------------------+------------------------+
| bert.pooler.dense.dropout                           |                      0 |
+-----------------------------------------------------+------------------------+
| bert.pooler.activation                              |                      0 |
+-----------------------------------------------------+------------------------+
| dropout                                             |                      0 |
+-----------------------------------------------------+------------------------+
| classifier                                          |                   1036 |
+-----------------------------------------------------+------------------------+
| classifier.linear                                   |                    256 |
+-----------------------------------------------------+------------------------+
| classifier.lora_a                                   |                    768 |
+-----------------------------------------------------+------------------------+
| classifier.lora_b                                   |                     12 |
+-----------------------------------------------------+------------------------+
| classifier.dropout                                  |                      0 |
+-----------------------------------------------------+------------------------+
| crossentropyloss_0                                  |                      0 |
+-----------------------------------------------------+------------------------+

Total Trainable Parameters: 3169816
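
A quick back-of-the-envelope check using the two totals reported above:

full_ft_total, lora_total = 14_480_258, 3_169_816
print(f"Reduction in trainable parameters: {full_ft_total / lora_total:.1f}x")  # ~4.6x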

In this case, LoRA reduces the number of trainable parameters by over \(4.5\times\)! We’ll run one more training epoch and evaluate the resulting accuracy.

# Re-create the trainer so it wraps the transformed (LoRA-injected) model
trainer = get_trainer(
    model=mg.model,
    tokenized_dataset=dataset,
    tokenizer=tokenizer,
    evaluate_metric="accuracy",
    num_train_epochs=1,
)
trainer.train()

# Evaluate accuracy
eval_results = trainer.evaluate()
print(f"Evaluation accuracy: {eval_results['eval_accuracy']}")
[782/782 00:44, Epoch 1/1]
Step Training Loss
500 0.441300

[782/782 01:02]
Evaluation accuracy: 0.82264

After training is finished, we can run the fuse_lora_weights_transform_pass pass to optimize the model for inference. This pass replaces each LoRALinear instance with a plain nn.Linear module whose weights are the original weight matrix with the \(AB\) product folded in. The fused model incurs fewer kernel invocations when deployed, which reduces inference runtime.

mg, _ = passes.fuse_lora_weights_transform_pass(mg)
eval_results = trainer.evaluate()
INFO     Fusing LoRALinear weights for bert.encoder.layer.0.attention.self.query.
INFO     Fusing LoRALinear weights for bert.encoder.layer.0.attention.self.key.
INFO     Fusing LoRALinear weights for bert.encoder.layer.0.attention.self.value.
INFO     Fusing LoRALinear weights for bert.encoder.layer.0.attention.output.dense.
INFO     Fusing LoRALinear weights for bert.encoder.layer.0.intermediate.dense.
INFO     Fusing LoRALinear weights for bert.encoder.layer.0.output.dense.
INFO     Fusing LoRALinear weights for bert.encoder.layer.1.attention.self.query.
INFO     Fusing LoRALinear weights for bert.encoder.layer.1.attention.self.key.
INFO     Fusing LoRALinear weights for bert.encoder.layer.1.attention.self.value.
INFO     Fusing LoRALinear weights for bert.encoder.layer.1.attention.output.dense.
INFO     Fusing LoRALinear weights for bert.encoder.layer.1.intermediate.dense.
INFO     Fusing LoRALinear weights for bert.encoder.layer.1.output.dense.
INFO     Fusing LoRALinear weights for bert.pooler.dense.
INFO     Fusing LoRALinear weights for classifier.
print(f"Evaluation accuracy: {eval_results['eval_accuracy']}")
Evaluation accuracy: 0.82264
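
The accuracy is identical before and after fusion, as expected - fusing just folds the trained low-rank update into the frozen weights, so the computed function is unchanged. In terms of the illustrative LoRALinearSketch module above (again a sketch, not Mase’s implementation), the fusion amounts to:

import torch

@torch.no_grad()
def fuse_lora_sketch(module):
    # nn.Linear stores its weight as (out_features, in_features),
    # so the (in, out)-shaped product A @ B is transposed before folding in
    delta = (module.lora_a @ module.lora_b).T * module.scaling
    module.linear.weight += delta
    return module.linear  # a plain nn.Linear: a single matmul at inference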

Finally, export the finetuned model to be used in future tutorials.

from pathlib import Path

mg.export(f"{Path.home()}/tutorial_2_lora")
INFO     Exporting MaseGraph to /home/zz7522/tutorial_2_lora.pt, /home/zz7522/tutorial_2_lora.mz
INFO     Exporting GraphModule to /home/zz7522/tutorial_2_lora.pt
INFO     Exporting MaseMetadata to /home/zz7522/tutorial_2_lora.mz
WARNING  Failed to pickle call_function node: finfo
WARNING  cannot pickle 'torch.finfo' object
WARNING  Failed to pickle call_function node: getattr_2
WARNING  cannot pickle 'torch.finfo' object

Conclusion#

By adjusting the rank of the LoRA adapter, we can control the trade-off between memory usage and fine-tuned accuracy. Such parameter-efficient fine-tuning techniques are especially valuable for large language models (LLMs), where the memory required for training is a significant bottleneck.
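
For instance, each adapted linear layer adds \(r \times (d_{in} + d_{out})\) trainable parameters, so the adapter overhead grows linearly with the rank. A quick illustrative calculation for bert-tiny (hidden size 128, intermediate size 512), which matches the lora_a and lora_b counts in the report above:

def lora_adapter_params(d_in, d_out, rank):
    """Trainable parameters added by one adapter: A is (d_in, rank), B is (rank, d_out)."""
    return rank * (d_in + d_out)

for r in (2, 6, 16):
    # e.g. an attention projection (128 -> 128) and the intermediate layer (128 -> 512)
    print(f"rank {r}: attention {lora_adapter_params(128, 128, r)}, "
          f"intermediate {lora_adapter_params(128, 512, r)}")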