Tutorial 2: Finetuning Bert for Sequence Classification using a LoRA adapter#
When we import a pretrained transformer model from HuggingFace, we receive the encoder/decoder weights, which aren't that useful on their own - to perform a useful task such as sequence classification, we add a classification head on top of the model and train its weights on the required dataset. In this tutorial, we'll look at fine-tuning a Bert model for sequence classification with two approaches. First, we'll attempt full Supervised Fine-Tuning (SFT). Then, we'll use the Mase stack to add a LoRA adapter to the model. In each case, we'll look at the memory required for training and the achieved accuracy.
checkpoint = "prajjwal1/bert-tiny"
tokenizer_checkpoint = "bert-base-uncased"
dataset_name = "imdb"
Sentiment Analysis with the IMDb Dataset#
The IMDb dataset, introduced in this 2011 paper from Stanford, is commonly used for sentiment analysis in the Natural Language Processing (NLP) community. It is a collection of 50k movie reviews from the IMDb website, each labelled as either “positive” or “negative”. Here is an example of a positive review:
I turned over to this film in the middle of the night and very nearly skipped right passed it. It was only because there was nothing else on that I decided to watch it. In the end, I thought it was great.
An interesting storyline, good characters, a clever script and brilliant directing makes this a fine film to sit down and watch. This was, in fact, the first I’d heard of this movie, but I would have been happy to have paid money to see this at the cinema.
My IMDB Rating : 8 out of 10
And a negative review:
its a totally average film with a few semi-alright action sequences that make the plot seem a little better and remind the viewer of the classic van dam films. parts of the plot don’t make sense and seem to be added in to use up time. the end plot is that of a very basic type that doesn’t leave the viewer guessing and any twists are obvious from the beginning. the end scene with the flask backs don’t make sense as they are added in and seem to have little relevance to the history of van dam’s character. not really worth watching again, bit disappointed in the end production, even though it is apparent it was shot on a low budget certain shots and sections in the film are of poor directed quality
The dataset is available from HuggingFace through the datasets library. We use the get_tokenized_dataset utility in Mase to automatically tokenize it.
from chop.tools import get_tokenized_dataset
dataset, tokenizer = get_tokenized_dataset(
dataset=dataset_name,
checkpoint=tokenizer_checkpoint,
return_tokenizer=True,
)
INFO Tokenizing dataset imdb with AutoTokenizer for bert-base-uncased.
Using the latest cached version of the dataset since imdb couldn't be found on the Hugging Face Hub
Found the latest cached dataset configuration 'plain_text' at /Users/yz10513/.cache/huggingface/datasets/imdb/plain_text/0.0.0/e6281661ce1c48d982bc483cf8a173c1bbeb5d31 (last modified on Sun Dec 1 05:38:09 2024).
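If you'd like to check what the tokenized data looks like, you can peek at a single example. The snippet below is a minimal sketch - it assumes the returned dataset follows the standard HuggingFace DatasetDict layout with a "train" split.
# Peek at one tokenized training example (assumes standard DatasetDict layout).
example = dataset["train"][0]
print(example.keys())
# Decode the first few token ids back to text to sanity-check the tokenizer.
print(tokenizer.decode(example["input_ids"][:20]))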
Generate a MaseGraph with Custom Arguments#
By inspecting the implementation of the Bert model in HuggingFace, we can see the forward function has a signature similar to the following.
class BertForSequenceClassification(BertPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.bert = BertModel(config)
...
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
...
By default, the MaseGraph constructor drives only the input_ids argument during symbolic tracing, ignoring the other optional arguments. However, you can specify which inputs to drive using the hf_input_names argument. In the following cell, we also drive the attention_mask and labels inputs. By specifying the labels argument, we include an nn.CrossEntropyLoss module at the end of the model to calculate the loss directly.
Task: Remove the attention_mask and labels arguments from the hf_input_names list and re-run the following cell. Use mg.draw() to visualize the graph in each case. Can you see any changes in the graph topology? Can you explain why this happens?
from transformers import AutoModelForSequenceClassification
from chop import MaseGraph
import chop.passes as passes
model = AutoModelForSequenceClassification.from_pretrained(checkpoint)
model.config.problem_type = "single_label_classification"
mg = MaseGraph(
model,
hf_input_names=[
"input_ids",
"attention_mask",
"labels",
],
)
mg, _ = passes.init_metadata_analysis_pass(mg)
mg, _ = passes.add_common_metadata_analysis_pass(mg)
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at prajjwal1/bert-tiny and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
`past_key_values` were not specified as input names, but model.config.use_cache = True. Setting model.config.use_cache = False.
INFO Getting dummy input for prajjwal1/bert-tiny.
/Users/yz10513/anaconda3/envs/mase/lib/python3.11/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
warnings.warn(
(The pass prints the dummy input tensors and each layer's intermediate activations as metadata is propagated through the graph, ending with the classifier logits and the labels tensor - several hundred lines of tensor output omitted.)
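As a starting point for the task above, re-tracing with only the input_ids input might look like the sketch below (mg_reduced is a hypothetical name); mg.draw() then lets you compare the two topologies.
# Sketch for the task: trace the model with only input_ids driven.
mg_reduced = MaseGraph(
    model,
    hf_input_names=["input_ids"],
)
mg_reduced.draw()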
Full Supervised Finetuning (SFT)#
Before training the model, let's inspect how many trainable parameters there are. If you're familiar with Keras, you might have used the model.summary() API before, but there's no equally convenient way to do this in PyTorch - luckily, Mase has a module-level pass with this functionality.
from chop.passes.module import report_trainable_parameters_analysis_pass
_, _ = report_trainable_parameters_analysis_pass(mg.model)
+-------------------------------------------------+------------------------+
| Submodule | Trainable Parameters |
+=================================================+========================+
| bert | 4385920 |
+-------------------------------------------------+------------------------+
| bert.embeddings | 3972864 |
+-------------------------------------------------+------------------------+
| bert.embeddings.word_embeddings | 3906816 |
+-------------------------------------------------+------------------------+
| bert.embeddings.token_type_embeddings | 256 |
+-------------------------------------------------+------------------------+
| bert.embeddings.position_embeddings | 65536 |
+-------------------------------------------------+------------------------+
| bert.embeddings.LayerNorm | 256 |
+-------------------------------------------------+------------------------+
| bert.embeddings.dropout | 0 |
+-------------------------------------------------+------------------------+
| bert.encoder | 396544 |
+-------------------------------------------------+------------------------+
| bert.encoder.layer | 396544 |
+-------------------------------------------------+------------------------+
| bert.encoder.layer.0 | 198272 |
+-------------------------------------------------+------------------------+
| bert.encoder.layer.0.attention | 66304 |
+-------------------------------------------------+------------------------+
| bert.encoder.layer.0.attention.self | 49536 |
+-------------------------------------------------+------------------------+
| bert.encoder.layer.0.attention.self.query | 16512 |
+-------------------------------------------------+------------------------+
| bert.encoder.layer.0.attention.self.key | 16512 |
+-------------------------------------------------+------------------------+
| bert.encoder.layer.0.attention.self.value | 16512 |
+-------------------------------------------------+------------------------+
| bert.encoder.layer.0.attention.output | 16768 |
+-------------------------------------------------+------------------------+
| bert.encoder.layer.0.attention.output.dense | 16512 |
+-------------------------------------------------+------------------------+
| bert.encoder.layer.0.attention.output.dropout | 0 |
+-------------------------------------------------+------------------------+
| bert.encoder.layer.0.attention.output.LayerNorm | 256 |
+-------------------------------------------------+------------------------+
| bert.encoder.layer.0.intermediate | 66048 |
+-------------------------------------------------+------------------------+
| bert.encoder.layer.0.intermediate.dense | 66048 |
+-------------------------------------------------+------------------------+
| bert.encoder.layer.0.output | 65920 |
+-------------------------------------------------+------------------------+
| bert.encoder.layer.0.output.dense | 65664 |
+-------------------------------------------------+------------------------+
| bert.encoder.layer.0.output.dropout | 0 |
+-------------------------------------------------+------------------------+
| bert.encoder.layer.0.output.LayerNorm | 256 |
+-------------------------------------------------+------------------------+
| bert.encoder.layer.1 | 198272 |
+-------------------------------------------------+------------------------+
| bert.encoder.layer.1.attention | 66304 |
+-------------------------------------------------+------------------------+
| bert.encoder.layer.1.attention.self | 49536 |
+-------------------------------------------------+------------------------+
| bert.encoder.layer.1.attention.self.query | 16512 |
+-------------------------------------------------+------------------------+
| bert.encoder.layer.1.attention.self.key | 16512 |
+-------------------------------------------------+------------------------+
| bert.encoder.layer.1.attention.self.value | 16512 |
+-------------------------------------------------+------------------------+
| bert.encoder.layer.1.attention.output | 16768 |
+-------------------------------------------------+------------------------+
| bert.encoder.layer.1.attention.output.dense | 16512 |
+-------------------------------------------------+------------------------+
| bert.encoder.layer.1.attention.output.dropout | 0 |
+-------------------------------------------------+------------------------+
| bert.encoder.layer.1.attention.output.LayerNorm | 256 |
+-------------------------------------------------+------------------------+
| bert.encoder.layer.1.intermediate | 66048 |
+-------------------------------------------------+------------------------+
| bert.encoder.layer.1.intermediate.dense | 66048 |
+-------------------------------------------------+------------------------+
| bert.encoder.layer.1.output | 65920 |
+-------------------------------------------------+------------------------+
| bert.encoder.layer.1.output.dense | 65664 |
+-------------------------------------------------+------------------------+
| bert.encoder.layer.1.output.dropout | 0 |
+-------------------------------------------------+------------------------+
| bert.encoder.layer.1.output.LayerNorm | 256 |
+-------------------------------------------------+------------------------+
| bert.pooler | 16512 |
+-------------------------------------------------+------------------------+
| bert.pooler.dense | 16512 |
+-------------------------------------------------+------------------------+
| bert.pooler.activation | 0 |
+-------------------------------------------------+------------------------+
| dropout | 0 |
+-------------------------------------------------+------------------------+
| classifier | 258 |
+-------------------------------------------------+------------------------+
| crossentropyloss_0 | 0 |
+-------------------------------------------------+------------------------+
Total Trainable Parameters: 14480258
From this, we can see that the majority of the trainable parameters are in the Embedding layer. We don't need to train these, so we freeze them in the cell below.
for param in mg.model.bert.embeddings.parameters():
param.requires_grad = False
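As a quick sanity check, you can also count trainable parameters directly in PyTorch - a minimal sketch:
# Count trainable vs. total parameters directly from the module.
trainable = sum(p.numel() for p in mg.model.parameters() if p.requires_grad)
total = sum(p.numel() for p in mg.model.parameters())
print(f"Trainable: {trainable} / {total} parameters")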
To train the model, we rely on the Trainer class from the transformers library, which makes it easy to set up a training loop on any hardware configuration. The get_trainer utility in Mase handles assigning the training arguments to the Trainer class for common use cases, such as in this tutorial.
from chop.tools import get_trainer
trainer = get_trainer(
model=mg.model,
tokenized_dataset=dataset,
tokenizer=tokenizer,
evaluate_metric="accuracy",
)
Before running any fine-tuning, let's see how the model performs out of the box. With a randomly initialized classification head, the model just makes a random guess - there are two labels in the dataset, so this corresponds to an accuracy of around 50%.
# Evaluate accuracy
eval_results = trainer.evaluate()
print(f"Evaluation accuracy: {eval_results['eval_accuracy']}")
100%|██████████| 3125/3125 [01:07<00:00, 46.59it/s]
Evaluation accuracy: 0.52164
Now, run the cell below to execute a single training epoch with the current setup.
trainer.train()
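If you're training on a GPU and want to quantify the memory requirement mentioned in the introduction, PyTorch's peak-memory counters can wrap the training call - a minimal sketch, CUDA only:
import torch

# Reset the counter, train, then read the peak allocation (CUDA only).
torch.cuda.reset_peak_memory_stats()
trainer.train()
print(f"Peak training memory: {torch.cuda.max_memory_allocated() / 1e9:.2f} GB")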
Let’s see how much accuracy we get after a single training epoch of full finetuning.
eval_results = trainer.evaluate()
print(f"Evaluation accuracy: {eval_results['eval_accuracy']}")
We can now export the SFT version of the model to be used in later tutorials.
from pathlib import Path
mg.export(f"{Path.home()}/tutorial_2_sft")
Parameter Efficient Finetuning (PEFT) with LoRA#
An alternative to full fine-tuning is Parameter-Efficient Fine-Tuning (PEFT), which uses a small number of trainable parameters to achieve similar performance. LoRA, proposed by a research team at Microsoft in 2021, is an efficient PEFT technique.
Consider the standard equation of a linear layer:

\[ y = Wx + b \]

The LoRA method involves replacing this with the following, where \(A\) and \(B\) are low-rank matrices, \(r\) is the adapter rank and \(\alpha\) is a scaling factor. We freeze the \(W\) parameters, and only allow the optimizer to train the parameters in \(A\) and \(B\):

\[ y = Wx + b + \frac{\alpha}{r} BAx \]
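For intuition, a LoRA linear layer can be written in a few lines of PyTorch. The sketch below is illustrative only - Mase's actual LoRALinear module may differ in its details - but it implements the equation above with the same rank, alpha and dropout knobs used by the pass further down.
import torch.nn as nn

class LoRALinearSketch(nn.Module):
    # Illustrative LoRA layer: y = Wx + b + (alpha / rank) * B(A(x))
    def __init__(self, in_features, out_features, rank=6, alpha=1.0, dropout=0.5):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)
        for p in self.linear.parameters():
            p.requires_grad = False  # freeze W and b
        self.lora_a = nn.Linear(in_features, rank, bias=False)
        self.lora_b = nn.Linear(rank, out_features, bias=False)
        nn.init.zeros_(self.lora_b.weight)  # adapter starts as a no-op
        self.dropout = nn.Dropout(dropout)
        self.scaling = alpha / rank

    def forward(self, x):
        # Frozen base projection plus the trainable low-rank update.
        return self.linear(x) + self.scaling * self.lora_b(self.lora_a(self.dropout(x)))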
This enables us to achieve accuracies comparable to full fine-tuning, while only training a fraction of the parameters. See the paper for more details. We can inject the LoRA adapter into the existing model using the insert_lora_adapter_transform_pass in Mase, as follows.
mg, _ = passes.insert_lora_adapter_transform_pass(
mg,
pass_args={
"rank": 6,
"alpha": 1.0,
"dropout": 0.5,
},
)
INFO Replaced node: bert_encoder_layer_0_attention_self_query, target: bert.encoder.layer.0.attention.self.query with LoRALinear module.
INFO Replaced node: bert_encoder_layer_0_attention_self_key, target: bert.encoder.layer.0.attention.self.key with LoRALinear module.
INFO Replaced node: bert_encoder_layer_0_attention_self_value, target: bert.encoder.layer.0.attention.self.value with LoRALinear module.
INFO Replaced node: bert_encoder_layer_0_attention_output_dense, target: bert.encoder.layer.0.attention.output.dense with LoRALinear module.
INFO Replaced node: bert_encoder_layer_0_intermediate_dense, target: bert.encoder.layer.0.intermediate.dense with LoRALinear module.
INFO Replaced node: bert_encoder_layer_0_output_dense, target: bert.encoder.layer.0.output.dense with LoRALinear module.
INFO Replaced node: bert_encoder_layer_1_attention_self_query, target: bert.encoder.layer.1.attention.self.query with LoRALinear module.
INFO Replaced node: bert_encoder_layer_1_attention_self_key, target: bert.encoder.layer.1.attention.self.key with LoRALinear module.
INFO Replaced node: bert_encoder_layer_1_attention_self_value, target: bert.encoder.layer.1.attention.self.value with LoRALinear module.
INFO Replaced node: bert_encoder_layer_1_attention_output_dense, target: bert.encoder.layer.1.attention.output.dense with LoRALinear module.
INFO Replaced node: bert_encoder_layer_1_intermediate_dense, target: bert.encoder.layer.1.intermediate.dense with LoRALinear module.
INFO Replaced node: bert_encoder_layer_1_output_dense, target: bert.encoder.layer.1.output.dense with LoRALinear module.
INFO Replaced node: bert_pooler_dense, target: bert.pooler.dense with LoRALinear module.
INFO Replaced node: classifier, target: classifier with LoRALinear module.
As before, let's report the number of trainable parameters.
_, _ = report_trainable_parameters_analysis_pass(mg.model)
+-----------------------------------------------------+------------------------+
| Submodule | Trainable Parameters |
+=====================================================+========================+
| bert | 439808 |
+-----------------------------------------------------+------------------------+
| bert.embeddings | 0 |
+-----------------------------------------------------+------------------------+
| bert.embeddings.word_embeddings | 0 |
+-----------------------------------------------------+------------------------+
| bert.embeddings.token_type_embeddings | 0 |
+-----------------------------------------------------+------------------------+
| bert.embeddings.position_embeddings | 0 |
+-----------------------------------------------------+------------------------+
| bert.embeddings.LayerNorm | 0 |
+-----------------------------------------------------+------------------------+
| bert.embeddings.dropout | 0 |
+-----------------------------------------------------+------------------------+
| bert.encoder | 421888 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer | 421888 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0 | 210944 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.attention | 71936 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.attention.self | 53760 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.attention.self.query | 17920 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.attention.self.query.linear | 16384 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.attention.self.query.lora_a | 768 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.attention.self.query.lora_b | 768 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.attention.self.query.dropout | 0 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.attention.self.key | 17920 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.attention.self.key.linear | 16384 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.attention.self.key.lora_a | 768 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.attention.self.key.lora_b | 768 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.attention.self.key.dropout | 0 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.attention.self.value | 17920 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.attention.self.value.linear | 16384 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.attention.self.value.lora_a | 768 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.attention.self.value.lora_b | 768 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.attention.self.value.dropout | 0 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.attention.output | 18176 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.attention.output.dense | 17920 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.attention.output.dense.linear | 16384 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.attention.output.dense.lora_a | 768 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.attention.output.dense.lora_b | 768 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.attention.output.dense.dropout | 0 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.attention.output.dropout | 0 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.attention.output.LayerNorm | 256 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.intermediate | 69376 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.intermediate.dense | 69376 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.intermediate.dense.linear | 65536 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.intermediate.dense.lora_a | 768 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.intermediate.dense.lora_b | 3072 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.intermediate.dense.dropout | 0 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.output | 69632 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.output.dense | 69376 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.output.dense.linear | 65536 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.output.dense.lora_a | 3072 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.output.dense.lora_b | 768 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.output.dense.dropout | 0 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.output.dropout | 0 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.0.output.LayerNorm | 256 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1 | 210944 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.attention | 71936 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.attention.self | 53760 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.attention.self.query | 17920 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.attention.self.query.linear | 16384 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.attention.self.query.lora_a | 768 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.attention.self.query.lora_b | 768 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.attention.self.query.dropout | 0 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.attention.self.key | 17920 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.attention.self.key.linear | 16384 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.attention.self.key.lora_a | 768 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.attention.self.key.lora_b | 768 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.attention.self.key.dropout | 0 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.attention.self.value | 17920 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.attention.self.value.linear | 16384 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.attention.self.value.lora_a | 768 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.attention.self.value.lora_b | 768 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.attention.self.value.dropout | 0 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.attention.output | 18176 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.attention.output.dense | 17920 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.attention.output.dense.linear | 16384 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.attention.output.dense.lora_a | 768 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.attention.output.dense.lora_b | 768 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.attention.output.dense.dropout | 0 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.attention.output.dropout | 0 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.attention.output.LayerNorm | 256 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.intermediate | 69376 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.intermediate.dense | 69376 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.intermediate.dense.linear | 65536 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.intermediate.dense.lora_a | 768 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.intermediate.dense.lora_b | 3072 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.intermediate.dense.dropout | 0 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.output | 69632 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.output.dense | 69376 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.output.dense.linear | 65536 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.output.dense.lora_a | 3072 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.output.dense.lora_b | 768 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.output.dense.dropout | 0 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.output.dropout | 0 |
+-----------------------------------------------------+------------------------+
| bert.encoder.layer.1.output.LayerNorm | 256 |
+-----------------------------------------------------+------------------------+
| bert.pooler | 17920 |
+-----------------------------------------------------+------------------------+
| bert.pooler.dense | 17920 |
+-----------------------------------------------------+------------------------+
| bert.pooler.dense.linear | 16384 |
+-----------------------------------------------------+------------------------+
| bert.pooler.dense.lora_a | 768 |
+-----------------------------------------------------+------------------------+
| bert.pooler.dense.lora_b | 768 |
+-----------------------------------------------------+------------------------+
| bert.pooler.dense.dropout | 0 |
+-----------------------------------------------------+------------------------+
| bert.pooler.activation | 0 |
+-----------------------------------------------------+------------------------+
| dropout | 0 |
+-----------------------------------------------------+------------------------+
| classifier | 1036 |
+-----------------------------------------------------+------------------------+
| classifier.linear | 256 |
+-----------------------------------------------------+------------------------+
| classifier.lora_a | 768 |
+-----------------------------------------------------+------------------------+
| classifier.lora_b | 12 |
+-----------------------------------------------------+------------------------+
| classifier.dropout | 0 |
+-----------------------------------------------------+------------------------+
| crossentropyloss_0 | 0 |
+-----------------------------------------------------+------------------------+
Total Trainable Parameters: 3169816
In this case, LoRA reduces the number of trainable parameters by \(4.5\times\)! We'll run one more training epoch and evaluate the resulting accuracy.
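As a quick sanity check, you can recompute this figure directly from the module. Here is a minimal sketch, assuming the LoRA pass freezes the base weights by setting requires_grad=False, so that only adapter and head parameters count as trainable:
# Illustrative check: compare trainable vs. total parameter counts,
# assuming frozen base weights have requires_grad=False.
trainable = sum(p.numel() for p in mg.model.parameters() if p.requires_grad)
total = sum(p.numel() for p in mg.model.parameters())
print(f"Trainable: {trainable:,} / {total:,} ({100 * trainable / total:.1f}%)")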
trainer = get_trainer(
model=mg.model,
tokenized_dataset=dataset,
tokenizer=tokenizer,
evaluate_metric="accuracy",
num_train_epochs=1,
)
trainer.train()
# Evaluate accuracy
eval_results = trainer.evaluate()
print(f"Evaluation accuracy: {eval_results['eval_accuracy']}")
16%|█▌ | 500/3125 [00:40<03:29, 12.52it/s]
{'loss': 0.6402, 'grad_norm': 2.730722665786743, 'learning_rate': 4.2e-05, 'epoch': 0.16}
32%|███▏ | 1000/3125 [01:13<03:24, 10.41it/s]
{'loss': 0.5216, 'grad_norm': 5.168735504150391, 'learning_rate': 3.4000000000000007e-05, 'epoch': 0.32}
48%|████▊ | 1501/3125 [01:44<01:20, 20.05it/s]
{'loss': 0.4751, 'grad_norm': 13.205789566040039, 'learning_rate': 2.6000000000000002e-05, 'epoch': 0.48}
64%|██████▍ | 2001/3125 [02:11<01:44, 10.79it/s]
{'loss': 0.4376, 'grad_norm': 14.12922191619873, 'learning_rate': 1.8e-05, 'epoch': 0.64}
80%|████████ | 2500/3125 [02:37<00:24, 25.37it/s]
{'loss': 0.4233, 'grad_norm': 4.498361587524414, 'learning_rate': 1e-05, 'epoch': 0.8}
96%|█████████▌| 3002/3125 [03:01<00:10, 11.99it/s]
{'loss': 0.4164, 'grad_norm': 8.473535537719727, 'learning_rate': 2.0000000000000003e-06, 'epoch': 0.96}
100%|██████████| 3125/3125 [03:07<00:00, 16.69it/s]
{'train_runtime': 187.2829, 'train_samples_per_second': 133.488, 'train_steps_per_second': 16.686, 'train_loss': 0.48366960693359373, 'epoch': 1.0}
100%|██████████| 3125/3125 [01:08<00:00, 45.64it/s]
Evaluation accuracy: 0.8218
After training is finished, we can run the fuse_lora_weights_transform_pass
to optimize the model for inference. This pass replaces each LoRALinear
instance with a plain nn.Linear
module, where the \(AB\) product is folded into the original weight matrix. The fused model incurs fewer kernel invocations when deployed, which reduces inference runtime.
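To make the fusion concrete, here is a minimal sketch of folding one adapter into its base layer. The names fuse_lora, lora_a, lora_b and scaling are illustrative placeholders, not Mase's internal attributes; the shapes assume \(A\) is (in_features, r) and \(B\) is (r, out_features):
import torch
import torch.nn as nn

def fuse_lora(linear: nn.Linear, lora_a: torch.Tensor, lora_b: torch.Tensor, scaling: float = 1.0) -> nn.Linear:
    # Build a plain nn.Linear with the same shape as the base layer.
    fused = nn.Linear(linear.in_features, linear.out_features, bias=linear.bias is not None)
    with torch.no_grad():
        # nn.Linear stores weight as (out_features, in_features), so the
        # (in_features, out_features)-shaped product AB is transposed before adding.
        fused.weight.copy_(linear.weight + scaling * (lora_a @ lora_b).T)
        if linear.bias is not None:
            fused.bias.copy_(linear.bias)
    return fused
Since \(W + AB\) is a single dense matrix, the adapter's extra matrix multiplications disappear entirely at inference time.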
mg, _ = passes.fuse_lora_weights_transform_pass(mg)
eval_results = trainer.evaluate()
print(f"Evaluation accuracy: {eval_results['eval_accuracy']}")
100%|██████████| 3125/3125 [01:26<00:00, 36.31it/s]
Evaluation accuracy: 0.8218
Note that the accuracy is identical before and after fusion, since the pass is mathematically equivalent to the unfused model.
Conclusion#
Finally, export the finetuned model to be used in future tutorials.
from pathlib import Path
mg.export(f"{Path.home()}/tutorial_2_lora")
INFO Exporting MaseGraph to /Users/yz10513/tutorial_2_lora.pt, /Users/yz10513/tutorial_2_lora.mz
INFO Exporting GraphModule to /Users/yz10513/tutorial_2_lora.pt
INFO Exporting MaseMetadata to /Users/yz10513/tutorial_2_lora.mz
WARNING Failed to pickle call_function node: finfo
WARNING cannot pickle 'torch.finfo' object
WARNING Failed to pickle call_function node: getattr_2
WARNING cannot pickle 'torch.finfo' object
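The torch.finfo warnings only affect metadata pickling for two graph nodes; as the logs show, the GraphModule itself is exported successfully. To pick the model up again in a later tutorial, the exported graph can be restored from the checkpoint. A sketch, assuming a MaseGraph.from_checkpoint loader is available (check the follow-up tutorial for the exact API):
from pathlib import Path
from chop import MaseGraph

# Assumed API: restore the exported graph from the .pt / .mz checkpoint pair.
mg = MaseGraph.from_checkpoint(f"{Path.home()}/tutorial_2_lora")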